From 6eba24db212cf02ee88b8a0286eb3f1d5b2e07e9 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Mon, 19 Sep 2022 10:47:52 +0200 Subject: [PATCH 01/66] initial implementation of plugabble data writers and buffered writer --- dlt/common/data_writers/__init__.py | 3 + dlt/common/data_writers/buffered.py | 81 ++++++++++++++ dlt/common/data_writers/escape.py | 21 ++++ dlt/common/data_writers/exceptions.py | 11 ++ dlt/common/data_writers/writers.py | 155 ++++++++++++++++++++++++++ dlt/common/dataset_writers.py | 67 ----------- 6 files changed, 271 insertions(+), 67 deletions(-) create mode 100644 dlt/common/data_writers/__init__.py create mode 100644 dlt/common/data_writers/buffered.py create mode 100644 dlt/common/data_writers/escape.py create mode 100644 dlt/common/data_writers/exceptions.py create mode 100644 dlt/common/data_writers/writers.py delete mode 100644 dlt/common/dataset_writers.py diff --git a/dlt/common/data_writers/__init__.py b/dlt/common/data_writers/__init__.py new file mode 100644 index 0000000000..89d4607c90 --- /dev/null +++ b/dlt/common/data_writers/__init__.py @@ -0,0 +1,3 @@ +from dlt.common.data_writers.writers import DataWriter, TLoaderFileFormat +from dlt.common.data_writers.buffered import BufferedDataWriter +from dlt.common.data_writers.escape import escape_redshift_literal, escape_redshift_identifier, escape_bigquery_identifier \ No newline at end of file diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py new file mode 100644 index 0000000000..c1c0c3651e --- /dev/null +++ b/dlt/common/data_writers/buffered.py @@ -0,0 +1,81 @@ +from typing import List, IO, Any + +from dlt.common.utils import uniq_id +from dlt.common.typing import TDataItem +from dlt.common.sources import TDirectDataItem +from dlt.common.data_writers import TLoaderFileFormat +from dlt.common.data_writers.exceptions import InvalidFileNameTemplateException +from dlt.common.data_writers.writers import DataWriter +from dlt.common.schema.typing import TTableSchemaColumns + + +class BufferedDataWriter: + def __init__(self, file_format: TLoaderFileFormat, file_name_template: str, buffer_max_items: int = 5000, file_max_bytes: int = None): + self.file_format = file_format + self._file_format_spec = DataWriter.data_format_from_file_format(self.file_format) + # validate if template has correct placeholders + self.file_name_template = file_name_template + self.all_files: List[str] = [] + self.buffer_max_items = buffer_max_items + self.file_max_bytes = file_max_bytes + + self._current_columns: TTableSchemaColumns = None + self._file_name: str = None + self._buffered_items: List[TDataItem] = [] + self._writer: DataWriter = None + self._file: IO[Any] = None + try: + self._rotate_file() + except TypeError: + raise InvalidFileNameTemplateException(file_name_template) + + def write_data_item(self, item: TDirectDataItem, columns: TTableSchemaColumns) -> None: + # rotate file if columns changed and writer does not allow for that + # as the only allowed change is to add new column (no updates/deletes), we detect the change by comparing lengths + if self._writer and not self._writer.data_format().supports_schema_changes and len(columns) != len(self._current_columns): + self._rotate_file() + # until the first chunk is written we can change the columns schema freely + self._current_columns = columns + if isinstance(item, List): + # items coming in single list will be written together, not matter how many are there + self._buffered_items.extend(item) + else: + 
self._buffered_items.append(item) + # flush if max buffer exceeded + if len(self._buffered_items) > self.buffer_max_items: + self._flush_items() + # rotate the file if max_bytes exceeded + if self.file_max_bytes and self._file and self._file.tell() > self.file_max_bytes: + self._rotate_file() + + def _rotate_file(self) -> None: + self.close_writer() + self._file_name = self.file_name_template % uniq_id() + "." + self._file_format_spec.file_extension + + def _flush_items(self) -> None: + if len(self._buffered_items) > 0: + # we only open a writer when there are any files in the buffer and first flush is requested + if not self._writer: + # create new writer and write header + if self._file_format_spec.is_binary_format: + self._file = open(self._file_name, "wb") + else: + self._file = open(self._file_name, "wt", encoding="utf-8") + self._writer = DataWriter.from_file_format(self.file_format, self._file) + self._writer.write_header(self._current_columns) + # write buffer + self._writer.write_data(self._buffered_items) + self._buffered_items.clear() + + def close_writer(self) -> None: + # if any buffered items exist, flush them + self._flush_items() + # if writer exists then close it + if self._writer: + # write the footer of a file + self._writer.write_footer() + # add file written to the list so we can commit all the files later + self.all_files.append(self._file_name) + self._file.close() + self._writer = None + self._file = None diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py new file mode 100644 index 0000000000..a8cef5e31d --- /dev/null +++ b/dlt/common/data_writers/escape.py @@ -0,0 +1,21 @@ +import re + +# use regex to escape characters in single pass +SQL_ESCAPE_DICT = {"'": "''", "\\": "\\\\", "\n": "\\n", "\r": "\\r"} +SQL_ESCAPE_RE = re.compile("|".join([re.escape(k) for k in sorted(SQL_ESCAPE_DICT, key=len, reverse=True)]), flags=re.DOTALL) + + +def escape_redshift_literal(v: str) -> str: + # https://www.postgresql.org/docs/9.3/sql-syntax-lexical.html + # looks like this is the only thing we need to escape for Postgres > 9.1 + # redshift keeps \ as escape character which is pre 9 behavior + return "{}{}{}".format("'", SQL_ESCAPE_RE.sub(lambda x: SQL_ESCAPE_DICT[x.group(0)], v), "'") + + +def escape_redshift_identifier(v: str) -> str: + return '"' + v.replace('"', '""').replace("\\", "\\\\") + '"' + + +def escape_bigquery_identifier(v: str) -> str: + # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical + return "`" + v.replace("\\", "\\\\").replace("`","\\`") + "`" diff --git a/dlt/common/data_writers/exceptions.py b/dlt/common/data_writers/exceptions.py new file mode 100644 index 0000000000..4f249eb142 --- /dev/null +++ b/dlt/common/data_writers/exceptions.py @@ -0,0 +1,11 @@ +from dlt.common.exceptions import DltException + + +class DataWriterException(DltException): + pass + + +class InvalidFileNameTemplateException(DataWriterException, ValueError): + def __init__(self, file_name_template: str): + self.file_name_template = file_name_template + super().__init__(f"Wrong file name template {file_name_template}. 
File name template must contain exactly one %s formatter") diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py new file mode 100644 index 0000000000..cc1a8fe212 --- /dev/null +++ b/dlt/common/data_writers/writers.py @@ -0,0 +1,155 @@ +import abc +import jsonlines +from dataclasses import dataclass +from typing import Any, Dict, Sequence, IO, Literal, Type +from datetime import date, datetime # noqa: I251 + +from dlt.common import json +from dlt.common.typing import StrAny +from dlt.common.json import json_typed_dumps +from dlt.common.schema.typing import TTableSchemaColumns +from dlt.common.data_writers.escape import escape_redshift_identifier, escape_redshift_literal + +TLoaderFileFormat = Literal["jsonl", "puae-jsonl", "insert_values"] + + +@dataclass +class TFileFormatSpec: + file_format: TLoaderFileFormat + file_extension: str + is_binary_format: bool + supports_schema_changes: bool + + +class DataWriter(abc.ABC): + def __init__(self, f: IO[Any]) -> None: + self._f = f + + @abc.abstractmethod + def write_header(self, columns_schema: TTableSchemaColumns) -> None: + pass + + @abc.abstractmethod + def write_data(self, rows: Sequence[Any]) -> None: + pass + + @abc.abstractmethod + def write_footer(self) -> None: + pass + + def write_all(self, columns_schema: TTableSchemaColumns, rows: Sequence[Any]) -> None: + self.write_header(columns_schema) + self.write_data(rows) + self.write_footer() + + + @classmethod + @abc.abstractmethod + def data_format(cls) -> TFileFormatSpec: + pass + + @classmethod + def from_file_format(cls, file_format: TLoaderFileFormat, f: IO[Any]) -> "DataWriter": + return cls.class_factory(file_format)(f) + + @classmethod + def data_format_from_file_format(cls, file_format: TLoaderFileFormat) -> TFileFormatSpec: + return cls.class_factory(file_format).data_format() + + @staticmethod + def class_factory(file_format: TLoaderFileFormat) -> Type["DataWriter"]: + if file_format == "jsonl": + return JsonlWriter + elif file_format == "puae-jsonl": + return JsonlPUAEncodeWriter + elif file_format == "insert_values": + return InsertValuesWriter + else: + raise ValueError(file_format) + + +class JsonlWriter(DataWriter): + + def write_header(self, columns_schema: TTableSchemaColumns) -> None: + pass + + def write_data(self, rows: Sequence[Any]) -> None: + # use jsonl to write load files https://jsonlines.org/ + with jsonlines.Writer(self._f, dumps=json.dumps) as w: + w.write_all(rows) + + def write_footer(self) -> None: + pass + + @classmethod + def data_format(cls) -> TFileFormatSpec: + return TFileFormatSpec("jsonl", "jsonl", False, True) + + +class JsonlPUAEncodeWriter(JsonlWriter): + + def write_data(self, rows: Sequence[Any]) -> None: + # encode types with PUA characters + with jsonlines.Writer(self._f, dumps=json_typed_dumps) as w: + w.write_all(rows) + + @classmethod + def data_format(cls) -> TFileFormatSpec: + return TFileFormatSpec("puae-jsonl", "jsonl", False, True) + + +class InsertValuesWriter(DataWriter): + + def __init__(self, f: IO[Any]) -> None: + super().__init__(f) + self._chunks_written = 0 + self._headers_lookup: Dict[str, int] = None + + def write_header(self, columns_schema: TTableSchemaColumns) -> None: + assert self._chunks_written == 0 + headers = columns_schema.keys() + # dict lookup is always faster + self._headers_lookup = {v: i for i, v in enumerate(headers)} + # do not write INSERT INTO command, this must be added together with table name by the loader + self._f.write("INSERT INTO {}(") + 
self._f.write(",".join(map(escape_redshift_identifier, headers))) + self._f.write(")\nVALUES\n") + + def write_data(self, rows: Sequence[Any]) -> None: + + def stringify(v: Any) -> str: + if isinstance(v, bytes): + return f"from_hex('{v.hex()}')" + if isinstance(v, (datetime, date)): + return escape_redshift_literal(v.isoformat()) + else: + return str(v) + + def write_row(row: StrAny) -> None: + output = ["NULL"] * len(self._headers_lookup) + for n,v in row.items(): + output[self._headers_lookup[n]] = escape_redshift_literal(v) if isinstance(v, str) else stringify(v) + self._f.write("(") + self._f.write(",".join(output)) + self._f.write(")") + + # if next chunk add separator + if self._chunks_written > 0: + self._f.write(",\n") + + # write rows + for row in rows[:-1]: + write_row(row) + self._f.write(",\n") + + # write last row without separator so we can write footer eventually + write_row(rows[-1]) + self._chunks_written += 1 + + def write_footer(self) -> None: + assert self._chunks_written > 0 + self._f.write(";") + + @classmethod + def data_format(cls) -> TFileFormatSpec: + return TFileFormatSpec("insert_values", "insert_values", False, False) diff --git a/dlt/common/dataset_writers.py b/dlt/common/dataset_writers.py deleted file mode 100644 index 67e1ea130d..0000000000 --- a/dlt/common/dataset_writers.py +++ /dev/null @@ -1,67 +0,0 @@ -import re -import jsonlines -from datetime import date, datetime # noqa: I251 -from typing import Any, Iterable, Literal, Sequence, IO - -from dlt.common import json -from dlt.common.typing import StrAny - -TLoaderFileFormat = Literal["jsonl", "insert_values"] - -# use regex to escape characters in single pass -SQL_ESCAPE_DICT = {"'": "''", "\\": "\\\\", "\n": "\\n", "\r": "\\r"} -SQL_ESCAPE_RE = re.compile("|".join([re.escape(k) for k in sorted(SQL_ESCAPE_DICT, key=len, reverse=True)]), flags=re.DOTALL) - - -def write_jsonl(f: IO[Any], rows: Sequence[Any]) -> None: - # use jsonl to write load files https://jsonlines.org/ - with jsonlines.Writer(f, dumps=json.dumps) as w: - w.write_all(rows) - - -def write_insert_values(f: IO[Any], rows: Sequence[StrAny], headers: Iterable[str]) -> None: - # dict lookup is always faster - headers_lookup = {v: i for i, v in enumerate(headers)} - # do not write INSERT INTO command, this must be added together with table name by the loader - f.write("INSERT INTO {}(") - f.write(",".join(map(escape_redshift_identifier, headers))) - f.write(")\nVALUES\n") - - def stringify(v: Any) -> str: - if isinstance(v, bytes): - return f"from_hex('{v.hex()}')" - if isinstance(v, (datetime, date)): - return escape_redshift_literal(v.isoformat()) - else: - return str(v) - - def write_row(row: StrAny) -> None: - output = ["NULL" for _ in range(len(headers_lookup))] - for n,v in row.items(): - output[headers_lookup[n]] = escape_redshift_literal(v) if isinstance(v, str) else stringify(v) - f.write("(") - f.write(",".join(output)) - f.write(")") - - for row in rows[:-1]: - write_row(row) - f.write(",\n") - - write_row(rows[-1]) - f.write(";") - - -def escape_redshift_literal(v: str) -> str: - # https://www.postgresql.org/docs/9.3/sql-syntax-lexical.html - # looks like this is the only thing we need to escape for Postgres > 9.1 - # redshift keeps \ as escape character which is pre 9 behavior - return "{}{}{}".format("'", SQL_ESCAPE_RE.sub(lambda x: SQL_ESCAPE_DICT[x.group(0)], v), "'") - - -def escape_redshift_identifier(v: str) -> str: - return '"' + v.replace('"', '""').replace("\\", "\\\\") + '"' - - -def escape_bigquery_identifier(v: str) 
-> str: - # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical - return "`" + v.replace("\\", "\\\\").replace("`","\\`") + "`" From 47e47661595ee2e0a17b120fee0e9fc8c8e70bbc Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Mon, 19 Sep 2022 10:48:51 +0200 Subject: [PATCH 02/66] experimental implementation of extraction pipe, source, resources and associated typing --- experiments/pipeline/async_decorator.py | 585 ------------------------ experiments/pipeline/extract.py | 112 +++++ experiments/pipeline/pipe.py | 377 +++++++++++++++ experiments/pipeline/pipeline.py | 12 +- experiments/pipeline/sources.py | 161 ++++--- experiments/pipeline/typing.py | 12 +- 6 files changed, 612 insertions(+), 647 deletions(-) delete mode 100644 experiments/pipeline/async_decorator.py create mode 100644 experiments/pipeline/extract.py create mode 100644 experiments/pipeline/pipe.py diff --git a/experiments/pipeline/async_decorator.py b/experiments/pipeline/async_decorator.py deleted file mode 100644 index b8fd4a3997..0000000000 --- a/experiments/pipeline/async_decorator.py +++ /dev/null @@ -1,585 +0,0 @@ -import asyncio -from collections import abc -from copy import deepcopy -from functools import wraps -from itertools import chain - -import inspect -import itertools -import os -import sys -from typing import Any, Coroutine, Dict, Iterator, List, NamedTuple, Sequence, cast - -from dlt.common.schema.schema import Schema -from dlt.common.schema.typing import TColumnSchema, TTableSchema, TTableSchemaColumns -from dlt.common.schema.utils import new_table -from dlt.common.sources import with_retry, with_table_name, get_table_name - -# from examples.sources.rasa_tracker_store - - -_meta = {} - -_i_schema: Schema = None -_i_info = None - -#abc.Iterator - - -class TableMetadataMixin: - def __init__(self, table: TTableSchema, schema: Schema = None): - self._table = table - self._schema = schema - self._table_name = table["name"] - self.__name__ = self._table_name - - @property - def table_schema(self): - # TODO: returns unified table schema by merging _schema and _table with table taking precedence - return self._table - - -class TableIterable(abc.Iterable, TableMetadataMixin): - def __init__(self, i, table, schema = None): - self._data = i - super().__init__(table, schema) - - def __iter__(self): - # TODO: this should resolve the _data like we do in the extract method: all awaitables and deferred items are resolved - # possibly in parallel. 
- if isinstance(self._data, abc.Iterator): - return TableIterator(self._data, self._table, self._schema) - return TableIterator(iter(self._data), self._table, self._schema) - - -class TableIterator(abc.Iterator, TableMetadataMixin): - def __init__(self, i, table, schema = None): - self.i = i - super().__init__(table, schema) - - def __next__(self): - # export metadata to global variable so it can be read by extractor - # TODO: remove this hack if possible - global _i_info - _i_info = cast(self, TableMetadataMixin) - - return next(self.i) - - def __iter__(self): - return self - - -class TableGenerator(abc.Generator, TableMetadataMixin): - def __init__(self, g, table, schema = None): - self.g = g - super().__init__(table, schema) - - def send(self, value): - return self.g.send(value) - - def throw(self, typ, val=None, tb=None): - return self.g.throw(typ, val, tb) - - -class SourceList(abc.Sequence): - def __init__(self, s, schema): - self.s: abc.Sequence = s - self.schema = schema - - # Sized - def __len__(self) -> int: - return self.s.__len__() - - # Iterator - def __next__(self): - return next(self.s) - - def __iter__(self): - return self - - # Container - def __contains__(self, value: object) -> bool: - return self.s.__contains__(value) - - # Reversible - def __reversed__(self): - return self.s.__reversed__() - - # Sequence - def __getitem__(self, index): - return self.s.__getitem__(index) - - def index(self, value: Any, start: int = ..., stop: int = ...) -> int: - return self.s.index(value, start, stop) - - def count(self, value: Any) -> int: - return self.s.count(value) - -class SourceTable(NamedTuple): - table_name: str - data: Iterator[Any] - - -def source(schema=None): - """This is source decorator""" - def _dec(f: callable): - print(f"calling source on {f.__name__}") - global _i_schema - - __dlt_schema = Schema(f.__name__) if not schema else schema - sig = inspect.signature(f) - - @wraps(f) - def _wrap(*args, **kwargs): - global _i_schema - - inner_schema: Schema = None - # if "schema" in kwargs and isinstance(kwargs["schema"], Schema): - # inner_schema = kwargs["schema"] - # # remove if not in sig - # if "schema" not in sig.parameters: - # del kwargs["schema"] - - _i_schema = inner_schema or __dlt_schema - rv = f(*args, **kwargs) - - if not isinstance(rv, (abc.Iterator, abc.Iterable)) or isinstance(rv, (dict, str)): - raise ValueError(f"Expected iterator/iterable containing tables {type(rv)}") - - # assume that source contain iterator of TableIterable - tables = [] - for table in rv: - # if not isinstance(rv, abc.Iterator) or isinstance(rv, (dict, str): - if not isinstance(table, TableIterable): - raise ValueError(f"Please use @table or as_table: {type(table)}") - tables.append(table) - # iterator consumed - clear schema - _i_schema = None - # if hasattr(rv, "__name__"): - # s.a - # source with single table - # return SourceList([rv], _i_schema) - # elif isinstance(rv, abc.Sequence): - # # peek what is inside - # item = None if len(rv) == 0 else rv[1] - # # if this is list, iterator or empty - # if isinstance(item, (NoneType, TableMetadataMixin, abc.Iterator)): - # return SourceList(rv, _i_schema) - # else: - # return SourceList([rv], _i_schema) - # else: - # raise ValueError(f"Unsupported table type {type(rv)}") - - return tables - # if isinstance(rv, abc.Iterable) or inspect(rv, abc.Iterator): - # yield from rv - # else: - # yield rv - print(f.__doc__) - _wrap.__doc__ = f.__doc__ + """This is source decorator""" - return _wrap - - # if isinstance(obj, callable): - # return _wrap 
- # else: - # return obj - return _dec - - -def table(name = None, write_disposition = None, parent = None, columns: Sequence[TColumnSchema] = None, schema = None): - def _dec(f: callable): - global _i_schema - - if _i_schema and schema: - raise Exception("Do not set explicit schema for a table in source context") - - l_schema = schema or _i_schema - table = new_table(name or f.__name__, parent, write_disposition, columns) - print(f"calling TABLE on {f.__name__}: {l_schema}") - - # @wraps(f, updated=('__dict__','__doc__')) - def _wrap(*args, **kwargs): - rv = f(*args, **kwargs) - return TableIterable(rv, table, l_schema) - # assert _i_info == None - - # def _yield_inner() : - # global _i_info - # print(f"TABLE: setting _i_info on {f.__name__} {l_schema}") - # _i_info = (table, l_schema) - - # if isinstance(rv, abc.Sequence): - # yield rv - # # return TableIterator(iter(rv), _i_info) - # elif isinstance(rv, abc.Generator): - # # return TableGenerator(rv, _i_info) - # yield from rv - # else: - # yield from rv - # _i_info = None - # # must clean up in extract - # # assert _i_info == None - - # gen_inner = _yield_inner() - # # generator name is a table name - # gen_inner.__name__ = "__dlt_meta:" + "*" if callable(table["name"]) else table["name"] - # # return generator - # return gen_inner - - # _i_info = None - # yield from map(lambda i: with_table_name(i, id(rv)), rv) - - return _wrap - - return _dec - - -def as_table(obj, name, write_disposition = None, parent = None, columns = None): - global _i_schema - l_schema = _i_schema - - # for i, f in sys._current_frames(): - # print(i, f) - - # print(sys._current_frames()) - - # try: - # for d in range(0, 10): - # c_f = sys._getframe(d) - # print(c_f.f_code.co_varnames) - # print("------------") - # if "__dlt_schema" in c_f.f_locals: - # l_schema = c_f.f_locals["__dlt_schema"] - # break - # except ValueError: - # # stack too deep - # pass - - # def inner(): - # # global _i_info - - # # assert _i_info == None - # print(f"AS_TABLE: setting _i_info on {name} {l_schema}") - # table = new_table(name, parent, write_disposition, columns) - # _i_info = (table, l_schema) - # if isinstance(obj, abc.Sequence): - # return TableIterator(iter(obj), _i_info) - # elif isinstance(obj, abc.Generator): - # return TableGenerator(obj, _i_info) - # else: - # return TableIterator(obj, _i_info) - # # if isinstance(obj, abc.Sequence): - # # yield obj - # # else: - # # yield from obj - # # _i_info = None - - table = new_table(name, parent, write_disposition, columns) - print(f"calling AS TABLE on {name}: {l_schema}") - return TableIterable(obj, table, l_schema) - # def _yield_inner() : - # global _i_info - # print(f"AS_TABLE: setting _i_info on {name} {l_schema}") - # _i_info = (table, l_schema) - - # if isinstance(obj, abc.Sequence): - # yield obj - # # return TableIterator(iter(obj), _i_info) - # elif isinstance(obj, abc.Generator): - # # return TableGenerator(obj, _i_info) - # yield from obj - # else: - # yield from obj - - # # must clean up in extract - # # assert _i_info == None - - # gen_inner = _yield_inner() - # # generator name is a table name - # gen_inner.__name__ = "__dlt_meta:" + "*" if callable(table["name"]) else table["name"] - # # return generator - # return gen_inner - - # return inner() - -# def async_table(write_disposition = None, parent = None, columns = None): - -# def _dec(f: callable): - -# def _wrap(*args, **kwargs): -# global _i_info - -# l_info = new_table(f.__name__, parent, write_disposition, columns) -# rv = f(*args, **kwargs) - -# for i in 
rv: -# # assert _i_info == None -# # print("set info") -# _i_info = l_info -# # print(f"what: {i}") -# yield i -# _i_info = None -# # print("reset info") - -# # else: -# # yield from rv -# # yield from map(lambda i: with_table_name(i, id(rv)), rv) - -# return _wrap - -# return _dec - - -# takes name from decorated function -@source() -def spotify(api_key=None): - """This is spotify source with several tables""" - - # takes name from decorated function - @table(write_disposition="append") - def numbers(): - return [1, 2, 3, 4] - - @table(write_disposition="replace") - def songs(library): - - # https://github.com/leshchenko1979/reretry - async def _i(id): - await asyncio.sleep(0.5) - # raise Exception("song cannot be taken") - return {f"song{id}": library} - - for i in range(3): - yield _i(i) - - @table(write_disposition="replace") - def albums(library): - - async def _i(id): - await asyncio.sleep(0.5) - return {f"album_{id}": library} - - - for i in ["X", "Y"]: - yield _i(i) - - @table(write_disposition="append") - def history(): - """This is your song history""" - print("HISTORY yield") - yield {"name": "dupa"} - - print("spotify returns list") - return ( - history(), - numbers(), - songs("lib_1"), - as_table(["lib_song"], name="library"), - albums("lib_2") - ) - - -@source() -def annotations(): - """Ad hoc annotation source""" - yield as_table(["ann1", "ann2", "ann3"], "annotate", write_disposition="replace") - - -# this table exists out of source context and will attach itself to the current default schema in the pipeline -@table(write_disposition="merge", parent="songs_authors") -def songs__copies(song, num): - return [{"song": song, "copy": True}] * num - - -event_column_template: List[TColumnSchema] = [{ - "name": "timestamp", - "data_type": "timestamp", - "nullable": False, - "partition": True, - "sorted": True - } -] - -# this demonstrates the content based naming of the tables for stream based sources -# same rule that applies to `name` could apply to `write_disposition` and `columns` -@table(name=lambda i: "event_" + i["event"], write_disposition="append", columns=event_column_template) -def events(): - from examples.sources.jsonl import get_source as read_jsonl - - sources = [ - read_jsonl(file) for file in os.scandir("examples/data/rasa_trackers") - ] - for i in chain(*sources): - yield { - "sender_id": i["sender_id"], - "timestamp": i["timestamp"], - "event": i["event"] - } - yield i - - -# another standalone source -authors = as_table(["authr"], "songs_authors") - - -# def source_with_schema_discovery(credentials, sheet_id, tab_id): - -# # discover the schema from actual API -# schema: Schema = get_schema_from_sheets(credentials, sheet_id) - -# # now declare the source -# @source(schema=schema) -# @table(name=schema.schema_name, write_disposition="replace") -# def sheet(): -# from examples.sources.google_sheets import get_source - -# yield from get_source(credentials, sheet_id, tab_id) - -# return sheet() - - -class Pipeline: - def __init__(self, parallelism = 2, default_schema: Schema = None) -> None: - self.sem = asyncio.Semaphore(parallelism) - self.schemas: Dict[str, Schema] = {} - self.default_schema_name: str = "" - if default_schema: - self.default_schema_name = default_schema.name - self.schemas[default_schema.name] = default_schema - - async def extract(self, items, schema: Schema = None): - # global _i_info - - l_info = None - if isinstance(items, TableIterable): - l_info = (items._table, items._schema) - print(f"extracting table with name {getattr(items, '__name__', 
None)} {l_info}") - - # if id(i) in meta: - # print(meta[id(i)]) - - def _p_i(item, what): - if l_info: - info_schema: Schema = l_info[1] - if info_schema: - # if already in pipeline - use the pipeline one - info_schema = self.schemas.get(info_schema.name) or info_schema - # if explicit - use explicit - eff_schema = schema or info_schema - if eff_schema is None: - # create default schema when needed - eff_schema = self.schemas.get(self.default_schema_name) or Schema(self.default_schema_name) - if eff_schema is not None: - table: TTableSchema = l_info[0] - # normalize table name - if callable(table["name"]): - table_name = table["name"](item) - else: - table_name = eff_schema.normalize_table_name(table["name"]) - - if table_name not in eff_schema._schema_tables: - table = deepcopy(table) - table["name"] = table_name - # TODO: normalize all other names - eff_schema.update_schema(table) - # TODO: l_info may contain type hints etc. - self.schemas[eff_schema.name] = eff_schema - if len(self.schemas) == 1: - self.default_schema_name = eff_schema.name - - print(f"{item} of {what} has HINT and will be written as {eff_schema.name}:{table_name}") - else: - eff_schema = self.schemas.get(self.default_schema_name) or Schema(self.default_schema_name) - print(f"{item} of {what} No HINT and will be written as {eff_schema.name}:table") - - # l_info = _i_info - if isinstance(items, TableIterable): - items = iter(items._data) - if isinstance(items, (abc.Sequence)): - items = iter(items) - # for item in items: - # _p_i(item, "list_item") - if inspect.isasyncgen(items): - raise NotImplemented() - else: - # context is set immediately - item = next(items, None) - if item is None: - return - - global _i_info - - if l_info is None and isinstance(_i_info, TableMetadataMixin): - l_info = (_i_info._table, _i_info._schema) - # l_info = _i_info - # _i_info = None - if inspect.iscoroutine(item) or isinstance(item, Coroutine): - async def _await(a_i): - async with self.sem: - # print("enter await") - a_itm = await a_i - _p_i(a_itm, "awaitable") - - items = await asyncio.gather( - asyncio.ensure_future(_await(item)), *(asyncio.ensure_future(_await(ii)) for ii in items) - ) - else: - _p_i(item, "iterator") - list(map(_p_i, items, itertools.repeat("iterator"))) - - # print("reset info") - # _i_info = None - # assert _i_info is None - - def extract_all(self, sources, schema: Schema = None): - loop = asyncio.get_event_loop() - loop.run_until_complete(asyncio.gather(*[self.extract(i, schema=schema) for i in sources])) - # loop.close() - - -default_schema = Schema("") - -print("s iter of iters") -s = spotify(api_key="7") -for items in s: - print(items.__name__) -s[0] = map(lambda d: {**d, **{"lambda": True}} , s[0]) -print("s2 iter of iters") -s2 = annotations() - -# for x in s[0]: -# print(x) -# exit() - -# print(list(s2)) -# exit(0) - -# mod albums - -def mapper(a): - a["mapper"] = True - return a - -# https://asyncstdlib.readthedocs.io/en/latest/# -# s[3] = map(mapper, s[3]) - - -# Pipeline().extract_all([s2]) -p = Pipeline(default_schema=Schema("default")) -chained = chain(s, s2, [authors], [songs__copies("abba", 4)], [["raw", "raw"]]) -# for items in chained: -# print(f"{type(items)}: {getattr(items, '__name__', 'NONE')}") -p.extract_all(chained, schema=None) -# p.extract_all([events()], schema=Schema("events")) -p.extract_all([["nein"] * 5]) -p.extract_all([as_table([], name="EMPTY")]) - - -for schema in p.schemas.values(): - print(schema.to_pretty_yaml(remove_defaults=True)) - -print(p.default_schema_name) -# for i in s: 
-# await extract(i) -# for i in s: -# extract(chain(*i, iter([1, "zeta"]))) \ No newline at end of file diff --git a/experiments/pipeline/extract.py b/experiments/pipeline/extract.py new file mode 100644 index 0000000000..076562c4d1 --- /dev/null +++ b/experiments/pipeline/extract.py @@ -0,0 +1,112 @@ +import os +from typing import Dict, List, Sequence, Type +from typing_extensions import reveal_type +from dlt.common.schema.typing import TTableSchemaColumns + +from dlt.common.utils import uniq_id +from dlt.common.sources import TDirectDataItem, TDataItem +from dlt.common.schema import utils, TSchemaUpdate +from dlt.common.data_writers import BufferedDataWriter +from dlt.common.storages import NormalizeStorage +from dlt.common.configuration import NormalizeVolumeConfiguration + + +from experiments.pipeline.pipe import PipeIterator +from experiments.pipeline.sources import DltResource, DltSource + + +class ExtractorStorage(NormalizeStorage): + EXTRACT_FOLDER = "extract" + EXTRACT_FILE_NAME_TEMPLATE = "" + + def __init__(self, C: Type[NormalizeVolumeConfiguration]) -> None: + super().__init__(False, C) + self.initialize_storage() + self.buffered_writers: Dict[str, BufferedDataWriter] = {} + + def initialize_storage(self) -> None: + self.storage.create_folder(ExtractorStorage.EXTRACT_FOLDER, exists_ok=True) + + def create_extract_temp_folder(self) -> str: + tf_name = uniq_id() + self.storage.create_folder(os.path.join(ExtractorStorage.EXTRACT_FOLDER, tf_name)) + return tf_name + + def commit_extract_files(self, temp_folder_name: str, with_delete: bool = True) -> None: + temp_path = os.path.join(os.path.join(ExtractorStorage.EXTRACT_FOLDER, temp_folder_name)) + for file in self.storage.list_folder_files(temp_path, to_root=False): + from_file = os.path.join(temp_path, file) + to_file = os.path.join(NormalizeStorage.EXTRACTED_FOLDER, file) + if with_delete: + self.storage.atomic_rename(from_file, to_file) + else: + # create hardlink which will act as a copy + self.storage.link_hard(from_file, to_file) + if with_delete: + self.storage.delete_folder(temp_path, recursively=True) + + def write_data_item(self, schema_name: str, table_name: str, item: TDirectDataItem, columns: TTableSchemaColumns) -> None: + # unique writer id + writer_id = f"{schema_name}.{table_name}" + writer = self.buffered_writers.get(writer_id, None) + if not writer_id: + # assign a jsonl writer with pua encoding for each table, use %s for file id to create required template + writer = BufferedDataWriter("puae-jsonl", NormalizeStorage.build_extracted_file_stem(schema_name, table_name, "%s")) + self.buffered_writers[writer_id] = writer + # write item(s) + writer.write_data_item(item, columns) + + def close_writers(self) -> None: + # flush and close all files + for writer in self.buffered_writers.values(): + writer.close_writer() + + +def extract(source: DltSource, storage: ExtractorStorage) -> TSchemaUpdate: + dynamic_tables: TSchemaUpdate = {} + schema = source.schema + + def _write_item(table_name: str, item: TDirectDataItem) -> None: + # normalize table name before writing so the name match the name in schema + # note: normalize function should be cached so there's almost no penalty on frequent calling + # note: column schema is not required for jsonl writer used here + storage.write_data_item(schema.name, schema.normalize_table_name(table_name), item, None) + + def _write_dynamic_table(resource: DltResource, item: TDataItem) -> None: + table_name = resource.table_name_hint_fun(item) + existing_table = 
dynamic_tables.get(table_name) + if existing_table is None: + dynamic_tables[table_name] = [resource.table_schema(item)] + else: + # quick check if deep table merge is required + if resource.table_has_other_dynamic_props: + new_table = resource.table_schema(item) + # this merges into existing table in place + utils.merge_tables(existing_table[0], new_table) + else: + # if there are no other dynamic hints besides name then we just leave the existing partial table + pass + # write to storage with inferred table name + _write_item(table_name, item) + + + # yield from all selected pipes + for pipe_item in PipeIterator.from_pipes(source.pipes): + # get partial table from table template + resource = source[pipe_item.pipe.name] + if resource.table_name_hint_fun: + if isinstance(pipe_item.item, List): + for item in pipe_item.item: + _write_dynamic_table(resource, item) + else: + _write_dynamic_table(resource, pipe_item.item) + else: + # write item belonging to table with static name + _write_item(resource.name, pipe_item.item) + + # flush all buffered writers + storage.close_writers() + + # returns set of partial tables + return dynamic_tables + diff --git a/experiments/pipeline/pipe.py b/experiments/pipeline/pipe.py new file mode 100644 index 0000000000..437336b6c1 --- /dev/null +++ b/experiments/pipeline/pipe.py @@ -0,0 +1,377 @@ +import types +import asyncio +from asyncio import Future +from concurrent.futures import ThreadPoolExecutor +from copy import deepcopy +from threading import Thread +from typing import Optional, Sequence, Union, Callable, Iterable, Iterator, List, NamedTuple, Awaitable, Tuple, Type, TYPE_CHECKING + +if TYPE_CHECKING: + TItemFuture = Future[TDirectDataItem] +else: + TItemFuture = Future + +from dlt.common.exceptions import DltException +from dlt.common.time import sleep +from dlt.common.sources import TDirectDataItem, TResolvableDataItem + + +class PipeItem(NamedTuple): + item: TDirectDataItem + step: int + pipe: "Pipe" + + +class ResolvablePipeItem(NamedTuple): + # mypy unable to handle recursive types, ResolvablePipeItem should take itself in "item" + item: Union[TResolvableDataItem, Iterator[TResolvableDataItem]] + step: int + pipe: "Pipe" + + +class FuturePipeItem(NamedTuple): + item: TItemFuture + step: int + pipe: "Pipe" + + +class SourcePipeItem(NamedTuple): + item: Union[Iterator[TResolvableDataItem], Iterator[ResolvablePipeItem]] + step: int + pipe: "Pipe" + + +# pipeline step may be iterator of data items or mapping function that returns data item or another iterator +TPipeStep = Union[ + Iterable[TResolvableDataItem], + Iterator[TResolvableDataItem], + Callable[[TDirectDataItem], TResolvableDataItem], + Callable[[TDirectDataItem], Iterator[TResolvableDataItem]], + Callable[[TDirectDataItem], Iterator[ResolvablePipeItem]] +] + + +class ForkPipe: + def __init__(self, pipe: "Pipe", step: int = -1) -> None: + self._pipes: List[Tuple["Pipe", int]] = [] + self.add_pipe(pipe, step) + + def add_pipe(self, pipe: "Pipe", step: int = -1) -> None: + if pipe not in self._pipes: + self._pipes.append((pipe, step)) + + def has_pipe(self, pipe: "Pipe") -> bool: + return pipe in [p[0] for p in self._pipes] + + def __call__(self, item: TDirectDataItem) -> Iterator[ResolvablePipeItem]: + for i, (pipe, step) in enumerate(self._pipes): + _it = item if i == 0 else deepcopy(item) + # always start at the beginning + yield ResolvablePipeItem(_it, step, pipe) + + +class FilterItem: + def __init__(self, filter_f: Callable[[TDirectDataItem], bool]) -> None: + self._filter_f = filter_f + 
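+    # used as a regular pipe step, e.g. pipe.add_step(FilterItem(lambda item: ...)), as DltResource.select does below
+    # the filter is applied to a single item or element-wise to a list; None is returned when nothing is left to pass on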
+ def __call__(self, item: TDirectDataItem) -> Optional[TDirectDataItem]: + # item may be a list TDataItem or a single TDataItem + if isinstance(item, list): + item = [i for i in item if self._filter_f(i)] + if not item: + # item was fully consumed by the filer + return None + return item + else: + return item if self._filter_f(item) else None + + +class Pipe: + def __init__(self, name: str, steps: List[TPipeStep] = None, depends_on: "Pipe" = None) -> None: + self.name = name + self._steps: List[TPipeStep] = steps or [] + self.depends_on = depends_on + + @classmethod + def from_iterable(cls, name: str, gen: Union[Iterable[TResolvableDataItem], Iterator[TResolvableDataItem]]) -> "Pipe": + if isinstance(gen, Iterable): + gen = iter(gen) + return cls(name, [gen]) + + @property + def head(self) -> TPipeStep: + return self._steps[0] + + @property + def tail(self) -> TPipeStep: + return self._steps[-1] + + @property + def steps(self) -> List[TPipeStep]: + return self._steps + + def __getitem__(self, i: int) -> TPipeStep: + return self._steps[i] + + def __len__(self) -> int: + return len(self._steps) + + def fork(self, child_pipe: "Pipe", child_step: int = -1) -> "Pipe": + if len(self._steps) == 0: + raise CreatePipeException("Cannot fork to empty pipe") + fork_step = self.tail + if not isinstance(fork_step, ForkPipe): + fork_step = ForkPipe(child_pipe, child_step) + self.add_step(fork_step) + else: + if not fork_step.has_pipe(child_pipe): + fork_step.add_pipe(child_pipe, child_step) + return self + + def clone(self) -> "Pipe": + return Pipe(self.name, self._steps.copy(), self.depends_on) + + def add_step(self, step: TPipeStep) -> "Pipe": + if len(self._steps) == 0 and self.depends_on is None: + # first element must be iterable or iterator + if not isinstance(step, (Iterable, Iterator)): + raise CreatePipeException("First step of independent pipe must be Iterable or Iterator") + else: + if isinstance(step, Iterable): + step = iter(step) + self._steps.append(step) + else: + if isinstance(step, (Iterable, Iterator)): + if self.depends_on is not None: + raise CreatePipeException("Iterable or Iterator cannot be a step in dependent pipe") + else: + raise CreatePipeException("Iterable or Iterator can only be a first step in independent pipe") + if not callable(step): + raise CreatePipeException("Pipe step must be a callable taking exactly one data item as input") + self._steps.append(step) + return self + + def full_pipe(self) -> "Pipe": + if self.depends_on: + pipe = self.depends_on.full_pipe().steps + else: + pipe = [] + + # return pipe with resolved dependencies + pipe.extend(self._steps) + return Pipe(self.name, pipe) + + +class PipeIterator(Iterator[PipeItem]): + + def __init__(self, max_parallelism: int = 100, worker_threads: int = 5, futures_poll_interval: float = 0.01) -> None: + self.max_parallelism = max_parallelism + self.worker_threads = worker_threads + self.futures_poll_interval = futures_poll_interval + + self._async_pool: asyncio.AbstractEventLoop = None + self._async_pool_thread: Thread = None + self._thread_pool: ThreadPoolExecutor = None + self._sources: List[SourcePipeItem] = [] + self._futures: List[FuturePipeItem] = [] + + @classmethod + def from_pipe(cls, pipe: Pipe, max_parallelism: int = 100, worker_threads: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": + if pipe.depends_on: + pipe = pipe.full_pipe() + # head must be iterator + assert isinstance(pipe.head, Iterator) + # create extractor + extract = cls(max_parallelism, worker_threads, futures_poll_interval) + 
# add as first source + extract._sources.append(SourcePipeItem(pipe.head, 0, pipe)) + return extract + + @classmethod + def from_pipes(cls, pipes: Sequence[Pipe], yield_parents: bool = True, max_parallelism: int = 100, worker_threads: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": + # as we add fork steps below, pipes are cloned before use + pipes = [p.clone() for p in pipes] + extract = cls(max_parallelism, worker_threads, futures_poll_interval) + for pipe in reversed(pipes): + if pipe.depends_on: + # fork the parent pipe + pipe.depends_on.fork(pipe) + # make the parent yield by sending a clone of item to itself with position at the end + if yield_parents: + pipe.depends_on.fork(pipe.depends_on, len(pipe.depends_on) - 1) + else: + # head of independent pipe must be iterator + assert isinstance(pipe.head, Iterator) + # add every head as source + extract._sources.append(SourcePipeItem(pipe.head, 0, pipe)) + return extract + + def __next__(self) -> PipeItem: + pipe_item: Union[ResolvablePipeItem, SourcePipeItem] = None + # __next__ should call itself to remove the `while` loop and continue clauses but that may lead to stack overflows: there's no tail recursion opt in python + # https://stackoverflow.com/questions/13591970/does-python-optimize-tail-recursion (see Y combinator on how it could be emulated) + while True: + # do we need new item? + if pipe_item is None: + # process element from the futures + if len(self._futures) > 0: + pipe_item = self._resolve_futures() + # if none then take element from the newest source + if pipe_item is None: + pipe_item = self._get_source_item() + + if pipe_item is None: + if len(self._futures) == 0 and len(self._sources) == 0: + # no more elements in futures or sources + raise StopIteration() + else: + # if len(_sources + # print("waiting") + sleep(self.futures_poll_interval) + continue + + # if item is iterator, then add it as a new source + if isinstance(pipe_item.item, Iterator): + # print(f"adding iterable {item}") + self._sources.append(SourcePipeItem(pipe_item.item, pipe_item.step, pipe_item.pipe)) + pipe_item = None + continue + + if isinstance(pipe_item.item, Awaitable) or callable(pipe_item.item): + # do we have a free slot or one of the slots is done? 
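+                # a slot is free when fewer futures are pending than max_parallelism, or when some pending future already completed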
+ if len(self._futures) < self.max_parallelism or self._next_future() >= 0: + if isinstance(pipe_item.item, Awaitable): + future = asyncio.run_coroutine_threadsafe(pipe_item.item, self._ensure_async_pool()) + else: + future = self._ensure_thread_pool().submit(pipe_item.item) + # print(future) + self._futures.append(FuturePipeItem(future, pipe_item.step, pipe_item.pipe)) # type: ignore + # pipe item consumed for now, request a new one + pipe_item = None + continue + else: + # print("maximum futures exceeded, waiting") + sleep(self.futures_poll_interval) + # try same item later + continue + + # if we are at the end of the pipe then yield element + # print(pipe_item) + if pipe_item.step == len(pipe_item.pipe) - 1: + # must be resolved + if isinstance(pipe_item.item, (Iterator, Awaitable)) or callable(pipe_item.pipe): + raise PipeItemProcessingError("Pipe item not processed", pipe_item) + # mypy not able to figure out that item was resolved + return pipe_item # type: ignore + + # advance to next step + step = pipe_item.pipe[pipe_item.step + 1] + assert callable(step) + item = step(pipe_item.item) + pipe_item = ResolvablePipeItem(item, pipe_item.step + 1, pipe_item.pipe) # type: ignore + + + def _ensure_async_pool(self) -> asyncio.AbstractEventLoop: + # lazily create async pool is separate thread + if self._async_pool: + return self._async_pool + + def start_background_loop(loop: asyncio.AbstractEventLoop) -> None: + asyncio.set_event_loop(loop) + loop.run_forever() + + self._async_pool = asyncio.new_event_loop() + self._async_pool_thread = Thread(target=start_background_loop, args=(self._async_pool,), daemon=True) + self._async_pool_thread.start() + + # start or return async pool + return self._async_pool + + def _ensure_thread_pool(self) -> ThreadPoolExecutor: + # lazily start or return thread pool + if self._thread_pool: + return self._thread_pool + + self._thread_pool = ThreadPoolExecutor(self.worker_threads) + return self._thread_pool + + def __enter__(self) -> "PipeIterator": + return self + + def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: types.TracebackType) -> None: + + def stop_background_loop(loop: asyncio.AbstractEventLoop) -> None: + loop.stop() + + for f, _, _ in self._futures: + if not f.done(): + f.cancel() + print("stopping loop") + if self._async_pool: + self._async_pool.call_soon_threadsafe(stop_background_loop, self._async_pool) + print("joining thread") + self._async_pool_thread.join() + self._async_pool = None + self._async_pool_thread = None + if self._thread_pool: + self._thread_pool.shutdown(wait=True) + self._thread_pool = None + + def _next_future(self) -> int: + return next((i for i, val in enumerate(self._futures) if val.item.done()), -1) + + def _resolve_futures(self) -> ResolvablePipeItem: + # no futures at all + if len(self._futures) == 0: + return None + + # anything done? 
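+        # pick the first completed future; -1 means nothing has finished yet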
+ idx = self._next_future() + if idx == -1: + # nothing done + return None + + future, step, pipe = self._futures.pop(idx) + + if future.cancelled(): + # get next future + return self._resolve_futures() + + if future.exception(): + raise future.exception() + + return ResolvablePipeItem(future.result(), step, pipe) + + def _get_source_item(self) -> ResolvablePipeItem: + # no more sources to iterate + if len(self._sources) == 0: + return None + + # get items from last added iterator, this makes the overall Pipe as close to FIFO as possible + gen, step, pipe = self._sources[-1] + try: + item = next(gen) + # full pipe item may be returned, this is used by ForkPipe step + # to redirect execution of an item to another pipe + if isinstance(item, ResolvablePipeItem): + return item + else: + # keep the item assigned step and pipe + return ResolvablePipeItem(item, step, pipe) + except StopIteration: + # remove empty iterator and try another source + self._sources.pop() + return self._get_source_item() + + +class PipeException(DltException): + pass + + +class CreatePipeException(PipeException): + pass + + +class PipeItemProcessingError(PipeException): + pass + diff --git a/experiments/pipeline/pipeline.py b/experiments/pipeline/pipeline.py index dac3425811..54c2113078 100644 --- a/experiments/pipeline/pipeline.py +++ b/experiments/pipeline/pipeline.py @@ -14,9 +14,7 @@ from dlt.common.runners import pool_runner as runner, TRunMetrics, initialize_runner from dlt.common.schema.utils import normalize_schema_name -from dlt.common.storages.live_schema_storage import LiveSchemaStorage -from dlt.common.storages.normalize_storage import NormalizeStorage -from dlt.common.storages.schema_storage import SchemaStorage +from dlt.common.storages import LiveSchemaStorage, NormalizeStorage from dlt.common.configuration import make_configuration, RunConfiguration, NormalizeVolumeConfiguration, SchemaVolumeConfiguration, ProductionNormalizeVolumeConfiguration from dlt.common.schema.schema import Schema @@ -33,7 +31,7 @@ from experiments.pipeline.configuration import get_config from experiments.pipeline.exceptions import PipelineConfigMissing, PipelineConfiguredException, MissingDependencyException, PipelineStepFailed -from experiments.pipeline.sources import SourceTables, TResolvableDataItem +from experiments.pipeline.sources import DltSource, TResolvableDataItem TConnectionString = NewType("TConnectionString", str) @@ -209,7 +207,7 @@ def extract( @overload def extract( self, - data: SourceTables, + data: DltSource, max_parallel_iterators: int = 1, max_parallel_data_items: int = 20, schema: Schema = None @@ -221,7 +219,7 @@ def extract( @with_state_sync def extract( self, - data: Union[SourceTables, Iterator[TResolvableDataItem], Iterable[TResolvableDataItem]], + data: Union[DltSource, Iterator[TResolvableDataItem], Iterable[TResolvableDataItem]], table_name = None, write_disposition = None, parent = None, @@ -435,7 +433,7 @@ def _extract_iterator(self, default_table_name: str, items: Sequence[DictStrAny] self._extractor_storage.save_json(f"{load_id}.json", items) self._extractor_storage.commit_events( self.default_schema.name, - self._extractor_storage.storage._make_path(f"{load_id}.json"), + self._extractor_storage.storage.make_full_path(f"{load_id}.json"), default_table_name, len(items), load_id diff --git a/experiments/pipeline/sources.py b/experiments/pipeline/sources.py index b2b5d1ca91..e727a9ea6e 100644 --- a/experiments/pipeline/sources.py +++ b/experiments/pipeline/sources.py @@ -1,86 +1,139 @@ from 
collections import abc -from typing import Iterable, Iterator, List, Union, Awaitable, Callable, Sequence, TypeVar, cast +from typing import AsyncIterable, Coroutine, Dict, Generator, Iterable, Iterator, List, TypedDict, Union, Awaitable, Callable, Sequence, TypeVar, cast, Optional, Any +from dlt.common.exceptions import DltException -from dlt.common.schema.schema import Schema -from dlt.common.schema.typing import TTableSchema from dlt.common.typing import TDataItem +from dlt.common.sources import TFunDataItemDynHint, TDirectDataItem +from dlt.common.schema.schema import Schema +from dlt.common.schema.typing import TPartialTableSchema, TTableSchema, TTableSchemaColumns, TWriteDisposition + +from experiments.pipeline.pipe import FilterItem, Pipe, CreatePipeException, PipeIterator + + +class TTableSchemaTemplate(TypedDict, total=False): + name: Union[str, TFunDataItemDynHint] + description: Union[str, TFunDataItemDynHint] + write_disposition: Union[TWriteDisposition, TFunDataItemDynHint] + # table_sealed: Optional[bool] + parent: Union[str, TFunDataItemDynHint] + columns: Union[TTableSchemaColumns, TFunDataItemDynHint] +# async def item(value: str) -> TDataItem: +# return {"str": value} -TDirectDataItem = Union[TDataItem, Sequence[TDataItem]] -TDeferredDataItem = Callable[[], TDirectDataItem] -TAwaitableDataItem = Awaitable[TDirectDataItem] -TResolvableDataItem = Union[TDirectDataItem, TDeferredDataItem, TAwaitableDataItem] +# import asyncio +# print(asyncio.run(item("a"))) +# exit(0) + + +# reveal_type(item) # TBoundItem = TypeVar("TBoundItem", bound=TDataItem) # TDeferreBoundItem = Callable[[], TBoundItem] -class TableMetadataMixin: - def __init__(self, table_schema: TTableSchema, schema: Schema = None, selected_tables: List[str] = None): - self._table_schema = table_schema - self.schema = schema - self._table_name = table_schema["name"] - self.__name__ = self._table_name - self.selected_tables = selected_tables - @property - def table_schema(self): - # TODO: returns unified table schema by merging _schema and _table with table taking precedence - return self._table_schema +class DltResourceSchema: + def __init__(self, name: str, table_schema_template: TTableSchemaTemplate): + self.__name__ = name + self.name = name + self.table_name_hint_fun: TFunDataItemDynHint = None + self.table_has_other_dynamic_props: bool = False + self._table_schema_template: TTableSchemaTemplate = None + self._set_template(table_schema_template) + def table_schema(self, item: TDataItem = None) -> TPartialTableSchema: -_i_info: TableMetadataMixin = None + def _resolve_hint(hint: Union[Any, TFunDataItemDynHint]) -> Any: + if callable(hint): + return hint(item) + else: + return hint + if self.table_name_hint_fun: + if item is None: + raise DataItemRequiredForDynamicTableHints(self.name) + else: + return {k: _resolve_hint(v) for k, v in self._table_schema_template.items()} + else: + return cast(TPartialTableSchema, self._table_schema_template) -def extractor_resolver(i: Union[Iterator[TResolvableDataItem], Iterable[TResolvableDataItem]], selected_tables: List[str] = None) -> Iterator[TDataItem]: + def _set_template(self, table_schema_template: TTableSchemaTemplate) -> None: + if callable(table_schema_template.get("name")): + self.table_name_hint_fun = table_schema_template.get("name") + self.table_has_other_dynamic_props = any(callable(v) for k, v in table_schema_template.items() if k != "name") + # check if template contains any functions + if self.table_has_other_dynamic_props and not self.table_name_hint_fun: + 
raise InvalidTableSchemaTemplate("Table name must be a function if any other table hint is a function") + self._table_schema_template = table_schema_template - if not isinstance(i, abc.Iterator): - i = iter(i) - # for item in i: +class DltResource(Iterable[TDirectDataItem], DltResourceSchema): + def __init__(self, pipe: Pipe, table_schema_template: TTableSchemaTemplate): + self.name = pipe.name + self.pipe = pipe + super().__init__(self.name, table_schema_template) + def select(self, *table_names: Iterable[str]) -> "DltResource": + if not self.table_name_hint_fun: + raise CreatePipeException("Table name is not dynamic, table selection impossible") + def _filter(item: TDataItem) -> bool: + return self.table_name(item) in table_names -class TableIterable(abc.Iterable, TableMetadataMixin): - def __init__(self, i, table, schema = None, selected_tables: List[str] = None): - self._data = i - super().__init__(table, schema, selected_tables) + # add filtering function at the end of pipe + self.pipe.add_step(FilterItem(_filter)) + return self - def __iter__(self): - # TODO: this should resolve the _data like we do in the extract method: all awaitables and deferred items are resolved - # possibly in parallel. - resolved_data = extractor_resolver(self._data) - return TableIterator(resolved_data, self._table_schema, self.schema, self.selected_tables) + def map(self) -> None: + pass + def flat_map(self) -> None: + pass + def filter(self) -> None: + pass -class TableIterator(abc.Iterator, TableMetadataMixin): - def __init__(self, i, table, schema = None, selected_tables: List[str] = None): - self.i = i - super().__init__(table, schema, selected_tables) + def __iter__(self) -> Iterator[TDirectDataItem]: + return map(lambda item: item.item, PipeIterator.from_pipe(self.pipe)) - # def __next__(self): - # # export metadata to global variable so it can be read by extractor - # # TODO: remove this hack if possible - # global _i_info - # _i_info = cast(self, TableMetadataMixin) +class DltSource(Iterable[TDirectDataItem]): + def __init__(self, schema: Schema, resources: Sequence[DltResource] = None) -> None: + self._schema = schema + self._resources: Dict[str, DltResource] = {} if resources is None else {r.name:r for r in resources} + self._disabled_resources: Sequence[str] = [] + super().__init__(self) - # if callable(self._table_name): - # else: - # if no table filter selected - # return next(self.i) - # while True: - # ni = next(self.i) - # if callable(self._table_name): - # # table name is a lambda, so resolve table name - # t_n = self._table_name(ni) - # return + def __getitem__(self, i: str) -> DltResource: + return self.resources[i] - # def __iter__(self): - # return self + @property + def resources(self) -> Sequence[DltResource]: + return [r for r in self._resources if r not in self._disabled_resources] + @property + def pipes(self) -> Sequence[Pipe]: + return [r.pipe for r in self._resources.values() if r.name not in self._disabled_resources] + + @property + def schema(self) -> Schema: + return self._schema + + def discover_schema(self) -> Schema: + # extract tables from all resources and update internal schema + # names must be normalized here + return self._schema -class SourceTables(List[TableIterable]): + def select(self, *resource_names: Iterable[str]) -> "DltSource": + pass + + def __iter__(self) -> Iterator[TDirectDataItem]: + return map(lambda item: item.item, PipeIterator.from_pipe(self.pipes)) + + +class DltSourceException(DltException): pass + +# class diff --git 
a/experiments/pipeline/typing.py b/experiments/pipeline/typing.py index d38bb9cba4..cb06d5a97f 100644 --- a/experiments/pipeline/typing.py +++ b/experiments/pipeline/typing.py @@ -1,4 +1,14 @@ from typing import Literal -TPipelineStep = Literal["extract", "normalize", "load"] \ No newline at end of file +TPipelineStep = Literal["extract", "normalize", "load"] + + +# class TTableSchema(TTableSchema, total=False): +# name: Optional[str] +# description: Optional[str] +# write_disposition: Optional[TWriteDisposition] +# table_sealed: Optional[bool] +# parent: Optional[str] +# filters: Optional[TRowFilters] +# columns: TTableSchemaColumns \ No newline at end of file From 20f6180cce786217241948d7a90e4071be1e1c69 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 20 Sep 2022 22:24:27 +0200 Subject: [PATCH 03/66] adds schema utils to diff tables --- dlt/common/schema/utils.py | 71 +++++++++++++++++++++------ tests/common/schema/test_inference.py | 19 +++---- 2 files changed, 63 insertions(+), 27 deletions(-) diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index ae2b0b0e07..08305ed7ce 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -18,8 +18,8 @@ from dlt.common.utils import str2bool from dlt.common.validation import TCustomValidator, validate_dict from dlt.common.schema import detections -from dlt.common.schema.typing import SIMPLE_REGEX_PREFIX, TColumnName, TNormalizersConfig, TSimpleRegex, TStoredSchema, TTableSchema, TTableSchemaColumns, TColumnSchemaBase, TColumnSchema, TColumnProp, TDataType, THintType, TTypeDetectionFunc, TTypeDetections, TWriteDisposition -from dlt.common.schema.exceptions import ParentTableNotFoundException, SchemaEngineNoUpgradePathException +from dlt.common.schema.typing import SIMPLE_REGEX_PREFIX, TColumnName, TNormalizersConfig, TPartialTableSchema, TSimpleRegex, TStoredSchema, TTableSchema, TTableSchemaColumns, TColumnSchemaBase, TColumnSchema, TColumnProp, TDataType, THintType, TTypeDetectionFunc, TTypeDetections, TWriteDisposition +from dlt.common.schema.exceptions import CannotCoerceColumnException, ParentTableNotFoundException, SchemaEngineNoUpgradePathException, SchemaException, TablePropertiesConflictException RE_LEADING_DIGITS = re.compile(r"^\d+") @@ -28,19 +28,6 @@ DEFAULT_WRITE_DISPOSITION: TWriteDisposition = "append" -# fix a name so it is acceptable as schema name -def normalize_schema_name(name: str) -> str: - # empty and None schema names are not allowed - if not name: - raise ValueError(name) - - # prefix the name starting with digits - if RE_LEADING_DIGITS.match(name): - name = "s" + name - # leave only alphanumeric - return RE_NON_ALPHANUMERIC.sub("", name).lower() - - def apply_defaults(stored_schema: TStoredSchema) -> None: for table_name, table in stored_schema["tables"].items(): # overwrite name @@ -432,7 +419,58 @@ def coerce_type(to_type: TDataType, from_type: TDataType, value: Any) -> Any: raise ValueError(value) -def compare_columns(a: TColumnSchema, b: TColumnSchema) -> bool: +def diff_tables(tab_a: TTableSchema, tab_b: TTableSchema, ignore_table_name: bool = True) -> TPartialTableSchema: + table_name = tab_a["name"] + if not ignore_table_name and table_name != tab_b["name"]: + raise TablePropertiesConflictException(table_name, "name", table_name, tab_b["name"]) + + # check if table properties can be merged + if tab_a.get("parent") != tab_b.get("parent"): + raise TablePropertiesConflictException(table_name, "parent", tab_a.get("parent"), tab_b.get("parent")) + # check if partial table has 
write disposition set + partial_w_d = tab_b.get("write_disposition") + if partial_w_d: + existing_w_d = tab_a.get("write_disposition") + if existing_w_d != partial_w_d: + raise TablePropertiesConflictException(table_name, "write_disposition", existing_w_d, partial_w_d) + + # get new columns, changes in the column data type or other properties are not allowed + table_columns = tab_a["columns"] + new_columns: List[TColumnSchema] = [] + for column in tab_b["columns"].values(): + column_name = column["name"] + if column_name in table_columns: + # we do not support changing existing columns + if not compare_column(table_columns[column_name], column): + # attempt to update to incompatible columns + raise CannotCoerceColumnException(table_name, column_name, column["data_type"], table_columns[column_name]["data_type"], None) + else: + new_columns.append(column) + + # TODO: compare filters, description etc. + + # return partial table containing only name and properties that differ (column, filters etc.) + return new_table(table_name, columns=new_columns) + + +def compare_tables(tab_a: TTableSchema, tab_b: TTableSchema) -> bool: + try: + diff_table = diff_tables(tab_a, tab_b, ignore_table_name=False) + # columns cannot differ + return len(diff_table["columns"]) == 0 + except SchemaException: + return False + + +def merge_tables(table: TTableSchema, partial_table: TPartialTableSchema) -> TTableSchema: + # merges "partial_table" into "table", preserving the "table" name + diff_table = diff_tables(table, partial_table, ignore_table_name=True) + # add new columns when all checks passed + table["columns"].update(diff_table["columns"]) + return table + + +def compare_column(a: TColumnSchema, b: TColumnSchema) -> bool: return a["data_type"] == b["data_type"] and a["nullable"] == b["nullable"] @@ -501,6 +539,7 @@ def new_table(table_name: str, parent_name: str = None, write_disposition: TWrit else: # set write disposition only for root tables table["write_disposition"] = write_disposition or DEFAULT_WRITE_DISPOSITION + print(f"new table {table_name} cid {id(table['columns'])}") return table diff --git a/tests/common/schema/test_inference.py b/tests/common/schema/test_inference.py index 37c94deded..f53b99f7b4 100644 --- a/tests/common/schema/test_inference.py +++ b/tests/common/schema/test_inference.py @@ -373,18 +373,15 @@ def test_update_schema_table_prop_conflict(schema: Schema) -> None: # without write disposition will merge del tab1_u2["write_disposition"] schema.update_schema(tab1_u2) - # child table merge checks recursively - child_tab1 = utils.new_table("child_tab", parent_name="tab_parent") - schema.update_schema(child_tab1) - child_tab1_u1 = deepcopy(child_tab1) - # parent table is replace - child_tab1_u1["write_disposition"] = "append" + # tab1 no write disposition, table update has write disposition + tab1["write_disposition"] = None + tab1_u2["write_disposition"] = "merge" + # this will not merge with pytest.raises(TablePropertiesConflictException) as exc_val: - schema.update_schema(child_tab1_u1) - assert exc_val.value.prop_name == "write_disposition" - # this will pass - child_tab1_u1["write_disposition"] = "replace" - schema.update_schema(child_tab1_u1) + schema.update_schema(tab1_u2) + # both write dispositions are None + tab1_u2["write_disposition"] = None + schema.update_schema(tab1_u2) def test_update_schema_column_conflict(schema: Schema) -> None: From ddd06a92d2c61b182597198dde6d934ad0672e65 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 20 Sep 2022 22:26:21 +0200 Subject: 
[PATCH 04/66] adds hard linking and path validation to file storage --- dlt/common/file_storage.py | 93 ++++++++++++++-------- tests/common/storages/test_file_storage.py | 77 ++++++++++++++++-- 2 files changed, 130 insertions(+), 40 deletions(-) diff --git a/dlt/common/file_storage.py b/dlt/common/file_storage.py index c626d4af2c..d75009a20b 100644 --- a/dlt/common/file_storage.py +++ b/dlt/common/file_storage.py @@ -1,7 +1,7 @@ import os import tempfile import shutil -from pathlib import Path +import pathvalidate from typing import IO, Any, List from dlt.common.utils import encoding_for_mode @@ -18,9 +18,9 @@ def __init__(self, if makedirs: os.makedirs(storage_path, exist_ok=True) - @classmethod - def from_file(cls, file_path: str, file_type: str = "t",) -> "FileStorage": - return cls(os.path.dirname(file_path), file_type) + # @classmethod + # def from_file(cls, file_path: str, file_type: str = "t",) -> "FileStorage": + # return cls(os.path.dirname(file_path), file_type) def save(self, relative_path: str, data: Any) -> str: return self.save_atomic(self.storage_path, relative_path, data, file_type=self.file_type) @@ -47,14 +47,14 @@ def load(self, relative_path: str) -> Any: return text_file.read() def delete(self, relative_path: str) -> None: - file_path = self._make_path(relative_path) + file_path = self.make_full_path(relative_path) if os.path.isfile(file_path): os.remove(file_path) else: raise FileNotFoundError(file_path) def delete_folder(self, relative_path: str, recursively: bool = False) -> None: - folder_path = self._make_path(relative_path) + folder_path = self.make_full_path(relative_path) if os.path.isdir(folder_path): if recursively: shutil.rmtree(folder_path) @@ -65,17 +65,17 @@ def delete_folder(self, relative_path: str, recursively: bool = False) -> None: def open_file(self, realtive_path: str, mode: str = "r") -> IO[Any]: mode = mode + self.file_type - return open(self._make_path(realtive_path), mode, encoding=encoding_for_mode(mode)) + return open(self.make_full_path(realtive_path), mode, encoding=encoding_for_mode(mode)) def open_temp(self, delete: bool = False, mode: str = "w", file_type: str = None) -> IO[Any]: mode = mode + file_type or self.file_type return tempfile.NamedTemporaryFile(dir=self.storage_path, mode=mode, delete=delete, encoding=encoding_for_mode(mode)) def has_file(self, relative_path: str) -> bool: - return os.path.isfile(self._make_path(relative_path)) + return os.path.isfile(self.make_full_path(relative_path)) def has_folder(self, relative_path: str) -> bool: - return os.path.isdir(self._make_path(relative_path)) + return os.path.isdir(self.make_full_path(relative_path)) def list_folder_files(self, relative_path: str, to_root: bool = True) -> List[str]: """List all files in ``relative_path`` folder @@ -87,7 +87,7 @@ def list_folder_files(self, relative_path: str, to_root: bool = True) -> List[st Returns: List[str]: A list of file names with optional path as per ``to_root`` parameter """ - scan_path = self._make_path(relative_path) + scan_path = self.make_full_path(relative_path) if to_root: # list files in relative path, returning paths relative to storage root return [os.path.join(relative_path, e.name) for e in os.scandir(scan_path) if e.is_file()] @@ -97,7 +97,7 @@ def list_folder_files(self, relative_path: str, to_root: bool = True) -> List[st def list_folder_dirs(self, relative_path: str, to_root: bool = True) -> List[str]: # list content of relative path, returning paths relative to storage root - scan_path = self._make_path(relative_path) + 
scan_path = self.make_full_path(relative_path) if to_root: # list folders in relative path, returning paths relative to storage root return [os.path.join(relative_path, e.name) for e in os.scandir(scan_path) if e.is_dir()] @@ -106,25 +106,32 @@ def list_folder_dirs(self, relative_path: str, to_root: bool = True) -> List[str return [e.name for e in os.scandir(scan_path) if e.is_dir()] def create_folder(self, relative_path: str, exists_ok: bool = False) -> None: - os.makedirs(self._make_path(relative_path), exist_ok=exists_ok) - - def copy_cross_storage_atomically(self, dest_volume_root: str, dest_relative_path: str, source_path: str, dest_name: str) -> None: - external_tmp_file = tempfile.mktemp(dir=dest_volume_root) - # first copy to temp file - shutil.copy(self._make_path(source_path), external_tmp_file) - # then rename to dest name - external_dest = os.path.join(dest_volume_root, dest_relative_path, dest_name) - try: - os.rename(external_tmp_file, external_dest) - except Exception: - if os.path.isfile(external_tmp_file): - os.remove(external_tmp_file) - raise + os.makedirs(self.make_full_path(relative_path), exist_ok=exists_ok) + + # def copy_cross_storage_atomically(self, dest_volume_root: str, dest_relative_path: str, source_path: str, dest_name: str) -> None: + # external_tmp_file = tempfile.mktemp(dir=dest_volume_root) + # # first copy to temp file + # shutil.copy(self.make_full_path(source_path), external_tmp_file) + # # then rename to dest name + # external_dest = os.path.join(dest_volume_root, dest_relative_path, dest_name) + # try: + # os.rename(external_tmp_file, external_dest) + # except Exception: + # if os.path.isfile(external_tmp_file): + # os.remove(external_tmp_file) + # raise + + def link_hard(self, from_relative_path: str, to_relative_path: str) -> None: + # note: some interesting stuff on links https://lightrun.com/answers/conan-io-conan-research-investigate-symlinks-and-hard-links + os.link( + self.make_full_path(from_relative_path), + self.make_full_path(to_relative_path) + ) def atomic_rename(self, from_relative_path: str, to_relative_path: str) -> None: os.rename( - self._make_path(from_relative_path), - self._make_path(to_relative_path) + self.make_full_path(from_relative_path), + self.make_full_path(to_relative_path) ) def in_storage(self, path: str) -> bool: @@ -138,11 +145,31 @@ def to_relative_path(self, path: str) -> str: raise ValueError(path) return os.path.relpath(path, start=self.storage_path) - def get_file_stem(self, path: str) -> str: - return Path(os.path.basename(path)).stem + def make_full_path(self, path: str) -> str: + # try to make a relative path is paths are absolute or overlapping + try: + path = self.to_relative_path(path) + except ValueError: + # if path is absolute and cannot be made relative to the storage then cannot be made full path with storage root + if os.path.isabs(path): + raise ValueError(path) - def get_file_name(self, path: str) -> str: - return Path(path).name + # then assume that it is a path relative to storage root + return os.path.join(self.storage_path, path) - def _make_path(self, relative_path: str) -> str: - return os.path.join(self.storage_path, relative_path) + @staticmethod + def validate_file_name_component(name: str) -> None: + # Universal platform bans several characters allowed in POSIX ie. | < \ or "COM1" :) + pathvalidate.validate_filename(name, platform="Universal") + # component cannot contain "." + if "." in name: + raise pathvalidate.error.InvalidCharError(reason="Component name cannot contain . 
(dots)") + pass + + # @staticmethod + # def get_file_stem(path: str) -> str: + # return Path(os.path.basename(path)).stem + + # @staticmethod + # def get_file_name(path: str) -> str: + # return Path(path).name diff --git a/tests/common/storages/test_file_storage.py b/tests/common/storages/test_file_storage.py index 7299bb777a..46e2bcc653 100644 --- a/tests/common/storages/test_file_storage.py +++ b/tests/common/storages/test_file_storage.py @@ -1,10 +1,73 @@ +import os +import pytest + from dlt.common.file_storage import FileStorage -from dlt.common.utils import encoding_for_mode +from dlt.common.utils import encoding_for_mode, uniq_id + +from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage, test_storage + + +# FileStorage(TEST_STORAGE_ROOT, makedirs=True) + + +def test_storage_init(test_storage: FileStorage) -> None: + # must be absolute path + assert os.path.isabs(test_storage.storage_path) + # may not contain file name (ends with / or \) + assert os.path.basename(test_storage.storage_path) == "" + + # TODO: write more cases + + +def test_make_full_path(test_storage: FileStorage) -> None: + # fully within storage + path = test_storage.make_full_path("dir/to/file") + assert path.endswith("/" + TEST_STORAGE_ROOT + "/dir/to/file") + # overlapped with storage + path = test_storage.make_full_path(f"{TEST_STORAGE_ROOT}/dir/to/file") + assert path.endswith("/" + TEST_STORAGE_ROOT + "/dir/to/file") + assert path.count(TEST_STORAGE_ROOT) == 1 + # absolute path with different root than TEST_STORAGE_ROOT + with pytest.raises(ValueError): + test_storage.make_full_path(f"/{TEST_STORAGE_ROOT}/dir/to/file") + # absolute overlapping path + path = test_storage.make_full_path(os.path.abspath(f"{TEST_STORAGE_ROOT}/dir/to/file")) + assert path.endswith("/" + TEST_STORAGE_ROOT + "/dir/to/file") + + +def test_hard_links(test_storage: FileStorage) -> None: + content = uniq_id() + test_storage.save("file.txt", content) + test_storage.link_hard("file.txt", "link.txt") + # it is a file + assert test_storage.has_file("link.txt") + # should have same content as file + assert test_storage.load("link.txt") == content + # should be linked + with test_storage.open_file("file.txt", mode="a") as f: + f.write(content) + assert test_storage.load("link.txt") == content * 2 + with test_storage.open_file("link.txt", mode="a") as f: + f.write(content) + assert test_storage.load("file.txt") == content * 3 + # delete original file + test_storage.delete("file.txt") + assert not test_storage.has_file("file.txt") + assert test_storage.load("link.txt") == content * 3 -from tests.utils import TEST_STORAGE +def test_validate_file_name_component() -> None: + # no dots + with pytest.raises(ValueError): + FileStorage.validate_file_name_component("a.b") + # no slashes + with pytest.raises(ValueError): + FileStorage.validate_file_name_component("a/b") + # no backslashes + with pytest.raises(ValueError): + FileStorage.validate_file_name_component("a\\b") -FileStorage(TEST_STORAGE, makedirs=True) + FileStorage.validate_file_name_component("BAN__ANA is allowed") def test_encoding_for_mode() -> None: @@ -17,15 +80,15 @@ def test_encoding_for_mode() -> None: def test_save_atomic_encode() -> None: tstr = "data'ऄअआइ''ईउऊऋऌऍऎए');" - FileStorage.save_atomic(TEST_STORAGE, "file.txt", tstr) - storage = FileStorage(TEST_STORAGE) + FileStorage.save_atomic(TEST_STORAGE_ROOT, "file.txt", tstr) + storage = FileStorage(TEST_STORAGE_ROOT) with storage.open_file("file.txt") as f: assert f.encoding == "utf-8" assert f.read() == tstr bstr = 
b"axa\0x0\0x0" - FileStorage.save_atomic(TEST_STORAGE, "file.bin", bstr, file_type="b") - storage = FileStorage(TEST_STORAGE, file_type="b") + FileStorage.save_atomic(TEST_STORAGE_ROOT, "file.bin", bstr, file_type="b") + storage = FileStorage(TEST_STORAGE_ROOT, file_type="b") with storage.open_file("file.bin", mode="r") as f: assert hasattr(f, "encoding") is False assert f.read() == bstr From 15d29b55f897cba87e504553ef30fbd6823ae911 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 20 Sep 2022 22:28:19 +0200 Subject: [PATCH 05/66] implements pipe, resource, source and extractor without tests --- dlt/common/data_writers/buffered.py | 2 +- dlt/common/sources.py | 10 +- dlt/pipeline/pipeline.py | 6 +- experiments/pipeline/__init__.py | 4 +- experiments/pipeline/extract.py | 59 +++++---- experiments/pipeline/pipe.py | 97 ++++++++++++--- experiments/pipeline/sources.py | 171 +++++++++++++++++++------- tests/load/redshift/test_pipelines.py | 104 ++++++++-------- 8 files changed, 304 insertions(+), 149 deletions(-) diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py index c1c0c3651e..26ca1c28f9 100644 --- a/dlt/common/data_writers/buffered.py +++ b/dlt/common/data_writers/buffered.py @@ -74,8 +74,8 @@ def close_writer(self) -> None: if self._writer: # write the footer of a file self._writer.write_footer() + self._file.close() # add file written to the list so we can commit all the files later self.all_files.append(self._file_name) - self._file.close() self._writer = None self._file = None diff --git a/dlt/common/sources.py b/dlt/common/sources.py index 0ae2bc4c48..ebeb1f4a70 100644 --- a/dlt/common/sources.py +++ b/dlt/common/sources.py @@ -1,6 +1,6 @@ from collections import abc from functools import wraps -from typing import Any, Callable, Optional, Sequence, TypeVar, Union, TypedDict +from typing import Any, Callable, Optional, Sequence, TypeVar, Union, TypedDict, List, Awaitable try: from typing_extensions import ParamSpec except ImportError: @@ -20,6 +20,14 @@ _TFunParams = ParamSpec("_TFunParams") +# TODO: cleanup those types +TDirectDataItem = Union[TDataItem, List[TDataItem]] +TDeferredDataItem = Callable[[], TDirectDataItem] +TAwaitableDataItem = Awaitable[TDirectDataItem] +TResolvableDataItem = Union[TDirectDataItem, TDeferredDataItem, TAwaitableDataItem] + +TFunDataItemDynHint = Callable[[TDataItem], Any] + # name of dlt metadata as part of the item DLT_METADATA_FIELD = "_dlt_meta" diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index e3502092c4..4d611fed37 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -13,7 +13,7 @@ from dlt.common.runners import pool_runner as runner, TRunMetrics, initialize_runner from dlt.common.configuration import PoolRunnerConfiguration, make_configuration from dlt.common.file_storage import FileStorage -from dlt.common.schema import Schema, normalize_schema_name +from dlt.common.schema import Schema from dlt.common.typing import DictStrAny, StrAny from dlt.common.utils import uniq_id, is_interactive from dlt.common.sources import DLT_METADATA_FIELD, TItem, with_table_name @@ -90,7 +90,7 @@ def create_pipeline( # create new schema if no default supplied if schema is None: # try to load schema, that will also import it - schema_name = normalize_schema_name(self.pipeline_name) + schema_name = self.pipeline_name try: schema = self._normalize_instance.schema_storage.load_schema(schema_name) except FileNotFoundError: @@ -382,7 +382,7 @@ def _extract_iterator(self, default_table_name: 
str, items: Sequence[DictStrAny] self.extractor_storage.save_json(f"{load_id}.json", items) self.extractor_storage.commit_events( self.default_schema_name, - self.extractor_storage.storage._make_path(f"{load_id}.json"), + self.extractor_storage.storage.make_full_path(f"{load_id}.json"), default_table_name, len(items), load_id diff --git a/experiments/pipeline/__init__.py b/experiments/pipeline/__init__.py index 68a3156325..f7a07ad9dc 100644 --- a/experiments/pipeline/__init__.py +++ b/experiments/pipeline/__init__.py @@ -1,6 +1,6 @@ -from experiments.pipeline.pipeline import Pipeline +# from experiments.pipeline.pipeline import Pipeline -pipeline = Pipeline() +# pipeline = Pipeline() # def __getattr__(name): # if name == 'y': diff --git a/experiments/pipeline/extract.py b/experiments/pipeline/extract.py index 076562c4d1..8ed57c1486 100644 --- a/experiments/pipeline/extract.py +++ b/experiments/pipeline/extract.py @@ -1,8 +1,8 @@ import os from typing import Dict, List, Sequence, Type -from typing_extensions import reveal_type -from dlt.common.schema.typing import TTableSchemaColumns +from dlt.common import logger +from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.utils import uniq_id from dlt.common.sources import TDirectDataItem, TDataItem from dlt.common.schema import utils, TSchemaUpdate @@ -27,15 +27,15 @@ def __init__(self, C: Type[NormalizeVolumeConfiguration]) -> None: def initialize_storage(self) -> None: self.storage.create_folder(ExtractorStorage.EXTRACT_FOLDER, exists_ok=True) - def create_extract_temp_folder(self) -> str: - tf_name = uniq_id() - self.storage.create_folder(os.path.join(ExtractorStorage.EXTRACT_FOLDER, tf_name)) - return tf_name + def create_extract_id(self) -> str: + extract_id = uniq_id() + self.storage.create_folder(self._get_extract_path(extract_id)) + return extract_id - def commit_extract_files(self, temp_folder_name: str, with_delete: bool = True) -> None: - temp_path = os.path.join(os.path.join(ExtractorStorage.EXTRACT_FOLDER, temp_folder_name)) - for file in self.storage.list_folder_files(temp_path, to_root=False): - from_file = os.path.join(temp_path, file) + def commit_extract_files(self, extract_id: str, with_delete: bool = True) -> None: + extract_path = self._get_extract_path(extract_id) + for file in self.storage.list_folder_files(extract_path, to_root=False): + from_file = os.path.join(extract_path, file) to_file = os.path.join(NormalizeStorage.EXTRACTED_FOLDER, file) if with_delete: self.storage.atomic_rename(from_file, to_file) @@ -43,43 +43,51 @@ def commit_extract_files(self, temp_folder_name: str, with_delete: bool = True) # create hardlink which will act as a copy self.storage.link_hard(from_file, to_file) if with_delete: - self.storage.delete_folder(temp_path, recursively=True) + self.storage.delete_folder(extract_path, recursively=True) - def write_data_item(self, schema_name: str, table_name: str, item: TDirectDataItem, columns: TTableSchemaColumns) -> None: + def write_data_item(self, extract_id: str, schema_name: str, table_name: str, item: TDirectDataItem, columns: TTableSchemaColumns) -> None: # unique writer id - writer_id = f"{schema_name}.{table_name}" + writer_id = f"{extract_id}.{schema_name}.{table_name}" writer = self.buffered_writers.get(writer_id, None) - if not writer_id: + if not writer: # assign a jsonl writer with pua encoding for each table, use %s for file id to create required template - writer = BufferedDataWriter("puae-jsonl", NormalizeStorage.build_extracted_file_stem(schema_name, table_name, 
"%s")) + template = NormalizeStorage.build_extracted_file_stem(schema_name, table_name, "%s") + path = self.storage.make_full_path(os.path.join(self._get_extract_path(extract_id), template)) + writer = BufferedDataWriter("puae-jsonl", path) self.buffered_writers[writer_id] = writer # write item(s) writer.write_data_item(item, columns) - def close_writers(self) -> None: + def close_writers(self, extract_id: str) -> None: # flush and close all files - for writer in self.buffered_writers.values(): - writer.close_writer() + for name, writer in self.buffered_writers.items(): + if name.startswith(extract_id): + logger.debug(f"Closing writer for {name} with file {writer._file} and actual name {writer._file_name}") + writer.close_writer() + + def _get_extract_path(self, extract_id: str) -> str: + return os.path.join(ExtractorStorage.EXTRACT_FOLDER, extract_id) def extract(source: DltSource, storage: ExtractorStorage) -> TSchemaUpdate: dynamic_tables: TSchemaUpdate = {} schema = source.schema + extract_id = storage.create_extract_id() def _write_item(table_name: str, item: TDirectDataItem) -> None: # normalize table name before writing so the name match the name in schema # note: normalize function should be cached so there's almost no penalty on frequent calling # note: column schema is not required for jsonl writer used here - storage.write_data_item(schema.name, schema.normalize_table_name(table_name), item, None) + storage.write_data_item(extract_id, schema.name, schema.normalize_table_name(table_name), item, None) def _write_dynamic_table(resource: DltResource, item: TDataItem) -> None: - table_name = resource.table_name_hint_fun(item) + table_name = resource._table_name_hint_fun(item) existing_table = dynamic_tables.get(table_name) if existing_table is None: dynamic_tables[table_name] = [resource.table_schema(item)] else: # quick check if deep table merge is required - if resource.table_has_other_dynamic_props: + if resource._table_has_other_dynamic_hints: new_table = resource.table_schema(item) # this merges into existing table in place utils.merge_tables(existing_table[0], new_table) @@ -89,12 +97,12 @@ def _write_dynamic_table(resource: DltResource, item: TDataItem) -> None: # write to storage with inferred table name _write_item(table_name, item) - # yield from all selected pipes for pipe_item in PipeIterator.from_pipes(source.pipes): # get partial table from table template - resource = source[pipe_item.pipe.name] - if resource.table_name_hint_fun: + print(pipe_item) + resource = source.resource_by_pipe(pipe_item.pipe) + if resource._table_name_hint_fun: if isinstance(pipe_item.item, List): for item in pipe_item.item: _write_dynamic_table(resource, item) @@ -105,7 +113,8 @@ def _write_dynamic_table(resource: DltResource, item: TDataItem) -> None: _write_item(resource.name, pipe_item.item) # flush all buffered writers - storage.close_writers() + storage.close_writers(extract_id) + storage.commit_extract_files(extract_id) # returns set of partial tables return dynamic_tables diff --git a/experiments/pipeline/pipe.py b/experiments/pipeline/pipe.py index 437336b6c1..810b542363 100644 --- a/experiments/pipeline/pipe.py +++ b/experiments/pipeline/pipe.py @@ -6,6 +6,8 @@ from threading import Thread from typing import Optional, Sequence, Union, Callable, Iterable, Iterator, List, NamedTuple, Awaitable, Tuple, Type, TYPE_CHECKING +from dlt.common.typing import TDataItem + if TYPE_CHECKING: TItemFuture = Future[TDirectDataItem] else: @@ -71,7 +73,7 @@ def __call__(self, item: TDirectDataItem) -> 
Iterator[ResolvablePipeItem]: class FilterItem: - def __init__(self, filter_f: Callable[[TDirectDataItem], bool]) -> None: + def __init__(self, filter_f: Callable[[TDataItem], bool]) -> None: self._filter_f = filter_f def __call__(self, item: TDirectDataItem) -> Optional[TDirectDataItem]: @@ -79,7 +81,7 @@ def __call__(self, item: TDirectDataItem) -> Optional[TDirectDataItem]: if isinstance(item, list): item = [i for i in item if self._filter_f(i)] if not item: - # item was fully consumed by the filer + # item was fully consumed by the filter return None return item else: @@ -87,10 +89,12 @@ def __call__(self, item: TDirectDataItem) -> Optional[TDirectDataItem]: class Pipe: - def __init__(self, name: str, steps: List[TPipeStep] = None, depends_on: "Pipe" = None) -> None: + def __init__(self, name: str, steps: List[TPipeStep] = None, parent: "Pipe" = None) -> None: self.name = name self._steps: List[TPipeStep] = steps or [] - self.depends_on = depends_on + self._backup_steps: List[TPipeStep] = None + self._pipe_id = f"{name}_{id(self)}" + self.parent = parent @classmethod def from_iterable(cls, name: str, gen: Union[Iterable[TResolvableDataItem], Iterator[TResolvableDataItem]]) -> "Pipe": @@ -129,10 +133,29 @@ def fork(self, child_pipe: "Pipe", child_step: int = -1) -> "Pipe": return self def clone(self) -> "Pipe": - return Pipe(self.name, self._steps.copy(), self.depends_on) + p = Pipe(self.name, self._steps.copy(), self.parent) + # clone shares the id with the original + p._pipe_id = self._pipe_id + return p + + # def backup(self) -> None: + # if self.has_backup: + # raise PipeBackupException("Pipe backup already exists, restore pipe first") + # self._backup_steps = self._steps.copy() + + # @property + # def has_backup(self) -> bool: + # return self._backup_steps is not None + + + # def restore(self) -> None: + # if not self.has_backup: + # raise PipeBackupException("No pipe backup to restore") + # self._steps = self._backup_steps + # self._backup_steps = None def add_step(self, step: TPipeStep) -> "Pipe": - if len(self._steps) == 0 and self.depends_on is None: + if len(self._steps) == 0 and self.parent is None: # first element must be iterable or iterator if not isinstance(step, (Iterable, Iterator)): raise CreatePipeException("First step of independent pipe must be Iterable or Iterator") @@ -142,7 +165,7 @@ def add_step(self, step: TPipeStep) -> "Pipe": self._steps.append(step) else: if isinstance(step, (Iterable, Iterator)): - if self.depends_on is not None: + if self.parent is not None: raise CreatePipeException("Iterable or Iterator cannot be a step in dependent pipe") else: raise CreatePipeException("Iterable or Iterator can only be a first step in independent pipe") @@ -152,8 +175,8 @@ def add_step(self, step: TPipeStep) -> "Pipe": return self def full_pipe(self) -> "Pipe": - if self.depends_on: - pipe = self.depends_on.full_pipe().steps + if self.parent: + pipe = self.parent.full_pipe().steps else: pipe = [] @@ -161,6 +184,9 @@ def full_pipe(self) -> "Pipe": pipe.extend(self._steps) return Pipe(self.name, pipe) + def __repr__(self) -> str: + return f"Pipe {self.name} ({self._pipe_id}) at {id(self)}" + class PipeIterator(Iterator[PipeItem]): @@ -177,7 +203,7 @@ def __init__(self, max_parallelism: int = 100, worker_threads: int = 5, futures_ @classmethod def from_pipe(cls, pipe: Pipe, max_parallelism: int = 100, worker_threads: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": - if pipe.depends_on: + if pipe.parent: pipe = pipe.full_pipe() # head must be iterator 
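# note: a dependent pipe (one with a parent set) is flattened above via full_pipe(), so the head of the pipe being iterated here is always the root data iterator; the assertion below guards that invariant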
assert isinstance(pipe.head, Iterator) @@ -189,21 +215,31 @@ def from_pipe(cls, pipe: Pipe, max_parallelism: int = 100, worker_threads: int = @classmethod def from_pipes(cls, pipes: Sequence[Pipe], yield_parents: bool = True, max_parallelism: int = 100, worker_threads: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": - # as we add fork steps below, pipes are cloned before use - pipes = [p.clone() for p in pipes] extract = cls(max_parallelism, worker_threads, futures_poll_interval) - for pipe in reversed(pipes): - if pipe.depends_on: + # clone all pipes before iterating (recursively) as we will fork them and this add steps + pipes = PipeIterator.clone_pipes(pipes) + + def _fork_pipeline(pipe: Pipe) -> None: + if pipe.parent: # fork the parent pipe - pipe.depends_on.fork(pipe) + pipe.parent.fork(pipe) # make the parent yield by sending a clone of item to itself with position at the end - if yield_parents: - pipe.depends_on.fork(pipe.depends_on, len(pipe.depends_on) - 1) + if yield_parents and pipe.parent in pipes: + # fork is last step of the pipe so it will yield + pipe.parent.fork(pipe.parent, len(pipe.parent) - 1) + _fork_pipeline(pipe.parent) else: # head of independent pipe must be iterator assert isinstance(pipe.head, Iterator) - # add every head as source + # add every head as source only once + if not any(i.pipe == pipe for i in extract._sources): + print("add to sources: " + pipe.name) extract._sources.append(SourcePipeItem(pipe.head, 0, pipe)) + + + for pipe in reversed(pipes): + _fork_pipeline(pipe) + return extract def __next__(self) -> PipeItem: @@ -363,6 +399,31 @@ def _get_source_item(self) -> ResolvablePipeItem: self._sources.pop() return self._get_source_item() + @staticmethod + def clone_pipes(pipes: Sequence[Pipe]) -> Sequence[Pipe]: + # will clone the pipes including the dependent ones + cloned_pipes = [p.clone() for p in pipes] + cloned_pairs = {id(p): c for p, c in zip(pipes, cloned_pipes)} + + for clone in cloned_pipes: + while True: + if not clone.parent: + break + # if already a clone + if clone.parent in cloned_pairs.values(): + break + # clone if parent pipe not yet cloned + if id(clone.parent) not in cloned_pairs: + print("cloning:" + clone.parent.name) + cloned_pairs[id(clone.parent)] = clone.parent.clone() + # replace with clone + print(f"replace depends on {clone.name} to {clone.parent.name}") + clone.parent = cloned_pairs[id(clone.parent)] + # recurr with clone + clone = clone.parent + + return cloned_pipes + class PipeException(DltException): pass diff --git a/experiments/pipeline/sources.py b/experiments/pipeline/sources.py index e727a9ea6e..95a55c1e0c 100644 --- a/experiments/pipeline/sources.py +++ b/experiments/pipeline/sources.py @@ -1,6 +1,9 @@ -from collections import abc -from typing import AsyncIterable, Coroutine, Dict, Generator, Iterable, Iterator, List, TypedDict, Union, Awaitable, Callable, Sequence, TypeVar, cast, Optional, Any +import contextlib +from copy import deepcopy +import inspect +from typing import AsyncIterable, AsyncIterator, Coroutine, Dict, Generator, Iterable, Iterator, List, Set, TypedDict, Union, Awaitable, Callable, Sequence, TypeVar, cast, Optional, Any from dlt.common.exceptions import DltException +from dlt.common.schema.utils import new_table from dlt.common.typing import TDataItem from dlt.common.sources import TFunDataItemDynHint, TDirectDataItem @@ -18,52 +21,48 @@ class TTableSchemaTemplate(TypedDict, total=False): parent: Union[str, TFunDataItemDynHint] columns: Union[TTableSchemaColumns, 
TFunDataItemDynHint] -# async def item(value: str) -> TDataItem: -# return {"str": value} - -# import asyncio - -# print(asyncio.run(item("a"))) -# exit(0) - - -# reveal_type(item) -# TBoundItem = TypeVar("TBoundItem", bound=TDataItem) -# TDeferreBoundItem = Callable[[], TBoundItem] - - class DltResourceSchema: - def __init__(self, name: str, table_schema_template: TTableSchemaTemplate): - self.__name__ = name + def __init__(self, name: str, table_schema_template: TTableSchemaTemplate = None): + # self.__name__ = name self.name = name - self.table_name_hint_fun: TFunDataItemDynHint = None - self.table_has_other_dynamic_props: bool = False + self._table_name_hint_fun: TFunDataItemDynHint = None + self._table_has_other_dynamic_hints: bool = False self._table_schema_template: TTableSchemaTemplate = None - self._set_template(table_schema_template) + if table_schema_template: + self._set_template(table_schema_template) def table_schema(self, item: TDataItem = None) -> TPartialTableSchema: + if not self._table_schema_template: + # if table template is not present, generate partial table from name + return new_table(self.name) + def _resolve_hint(hint: Union[Any, TFunDataItemDynHint]) -> Any: if callable(hint): return hint(item) else: return hint - if self.table_name_hint_fun: + # if table template present and has dynamic hints, the data item must be provided + if self._table_name_hint_fun: if item is None: raise DataItemRequiredForDynamicTableHints(self.name) else: - return {k: _resolve_hint(v) for k, v in self._table_schema_template.items()} + cloned_template = deepcopy(self._table_schema_template) + return cast(TPartialTableSchema, {k: _resolve_hint(v) for k, v in cloned_template.items()}) else: return cast(TPartialTableSchema, self._table_schema_template) def _set_template(self, table_schema_template: TTableSchemaTemplate) -> None: - if callable(table_schema_template.get("name")): - self.table_name_hint_fun = table_schema_template.get("name") - self.table_has_other_dynamic_props = any(callable(v) for k, v in table_schema_template.items() if k != "name") - # check if template contains any functions - if self.table_has_other_dynamic_props and not self.table_name_hint_fun: + # if "name" is callable in the template then the table schema requires actual data item to be inferred + name_hint = table_schema_template.get("name") + if callable(name_hint): + self._table_name_hint_fun = name_hint + # check if any other hints in the table template should be inferred from data + self._table_has_other_dynamic_hints = any(callable(v) for k, v in table_schema_template.items() if k != "name") + + if self._table_has_other_dynamic_hints and not self._table_name_hint_fun: raise InvalidTableSchemaTemplate("Table name must be a function if any other table hint is a function") self._table_schema_template = table_schema_template @@ -71,49 +70,109 @@ def _set_template(self, table_schema_template: TTableSchemaTemplate) -> None: class DltResource(Iterable[TDirectDataItem], DltResourceSchema): def __init__(self, pipe: Pipe, table_schema_template: TTableSchemaTemplate): self.name = pipe.name - self.pipe = pipe + self._pipe = pipe super().__init__(self.name, table_schema_template) + @classmethod + def from_data(cls, data: Any, name: str = None, table_schema_template: TTableSchemaTemplate = None) -> "DltResource": + # call functions assuming that they do not take any parameters, typically they are generator functions + if callable(data): + data = data() + + if isinstance(data, DltResource): + return data + + if 
isinstance(data, Pipe): + return cls(data, table_schema_template) + + # several iterable types are not allowed and must be excluded right away + if isinstance(data, (AsyncIterator, AsyncIterable, str, dict)): + raise InvalidResourceDataType("Invalid data type for DltResource", type(data)) + + # create resource from iterator or iterable + if isinstance(data, (Iterable, Iterator)): + if inspect.isgenerator(data): + name = name or data.__name__ + else: + name = name or None + if not name: + raise ResourceNameRequired("The DltResource name was not provide or could not be inferred.") + pipe = Pipe.from_iterable(name, data) + return cls(pipe, table_schema_template) + + # some other data type that is not supported + raise InvalidResourceDataType("Invalid data type for DltResource", type(data)) + + def select(self, *table_names: Iterable[str]) -> "DltResource": - if not self.table_name_hint_fun: + if not self._table_name_hint_fun: raise CreatePipeException("Table name is not dynamic, table selection impossible") def _filter(item: TDataItem) -> bool: - return self.table_name(item) in table_names + return self._table_name_hint_fun(item) in table_names # add filtering function at the end of pipe - self.pipe.add_step(FilterItem(_filter)) + self._pipe.add_step(FilterItem(_filter)) return self def map(self) -> None: - pass + raise NotImplementedError() def flat_map(self) -> None: - pass + raise NotImplementedError() def filter(self) -> None: - pass + raise NotImplementedError() def __iter__(self) -> Iterator[TDirectDataItem]: - return map(lambda item: item.item, PipeIterator.from_pipe(self.pipe)) + return map(lambda item: item.item, PipeIterator.from_pipe(self._pipe)) + + def __repr__(self) -> str: + return f"DltResource {self.name} ({self._pipe._pipe_id}) at {id(self)}" + class DltSource(Iterable[TDirectDataItem]): def __init__(self, schema: Schema, resources: Sequence[DltResource] = None) -> None: + self.name = schema.name self._schema = schema - self._resources: Dict[str, DltResource] = {} if resources is None else {r.name:r for r in resources} - self._disabled_resources: Sequence[str] = [] - super().__init__(self) + self._resources: List[DltResource] = list(resources or []) + self._enabled_resource_names: Set[str] = set(r.name for r in self._resources) + + @classmethod + def from_data(cls, schema: Schema, data: Any) -> "DltSource": + # creates source from various forms of data + if isinstance(data, DltSource): + return data + + # several iterable types are not allowed and must be excluded right away + if isinstance(data, (AsyncIterator, AsyncIterable, str, dict)): + raise InvalidSourceDataType("Invalid data type for DltSource", type(data)) + + # in case of sequence, enumerate items and convert them into resources + if isinstance(data, Sequence): + resources = [DltResource.from_data(i) for i in data] + else: + resources = [DltResource.from_data(data)] + + return cls(schema, resources) + - def __getitem__(self, i: str) -> DltResource: - return self.resources[i] + def __getitem__(self, name: str) -> List[DltResource]: + if name not in self._enabled_resource_names: + raise KeyError(name) + return [r for r in self._resources if r.name == name] + + def resource_by_pipe(self, pipe: Pipe) -> DltResource: + # identify pipes by memory pointer + return next(r for r in self._resources if r._pipe._pipe_id is pipe._pipe_id) @property def resources(self) -> Sequence[DltResource]: - return [r for r in self._resources if r not in self._disabled_resources] + return [r for r in self._resources if r.name in 
self._enabled_resource_names] @property def pipes(self) -> Sequence[Pipe]: - return [r.pipe for r in self._resources.values() if r.name not in self._disabled_resources] + return [r._pipe for r in self._resources if r.name in self._enabled_resource_names] @property def schema(self) -> Schema: @@ -121,19 +180,37 @@ def schema(self) -> Schema: def discover_schema(self) -> Schema: # extract tables from all resources and update internal schema - # names must be normalized here + for r in self._resources: + # names must be normalized here + with contextlib.suppress(DataItemRequiredForDynamicTableHints): + partial_table = self._schema.normalize_table_identifiers(r.table_schema()) + self._schema.update_schema(partial_table) return self._schema - def select(self, *resource_names: Iterable[str]) -> "DltSource": - pass + def select(self, *resource_names: str) -> "DltSource": + # make sure all selected resources exist + for name in resource_names: + self.__getitem__(name) + self._enabled_resource_names = set(resource_names) + return self + def __iter__(self) -> Iterator[TDirectDataItem]: - return map(lambda item: item.item, PipeIterator.from_pipe(self.pipes)) + return map(lambda item: item.item, PipeIterator.from_pipes(self.pipes)) + + def __repr__(self) -> str: + return f"DltSource {self.name} at {id(self)}" class DltSourceException(DltException): pass +class DataItemRequiredForDynamicTableHints(DltException): + def __init__(self, resource_name: str) -> None: + self.resource_name = resource_name + super().__init__(f"Instance of Data Item required to generate table schema in resource {resource_name}") + + # class diff --git a/tests/load/redshift/test_pipelines.py b/tests/load/redshift/test_pipelines.py index 77d16c4932..30e2936708 100644 --- a/tests/load/redshift/test_pipelines.py +++ b/tests/load/redshift/test_pipelines.py @@ -1,66 +1,66 @@ -import os -import pytest -from os import environ +# import os +# import pytest +# from os import environ -from dlt.common.schema.schema import Schema -from dlt.common.utils import uniq_id -from dlt.pipeline import Pipeline, PostgresPipelineCredentials -from dlt.pipeline.exceptions import InvalidPipelineContextException +# from dlt.common.schema.schema import Schema +# from dlt.common.utils import uniq_id +# from dlt.pipeline import Pipeline, PostgresPipelineCredentials +# from dlt.pipeline.exceptions import InvalidPipelineContextException -from tests.utils import autouse_root_storage, TEST_STORAGE +# from tests.utils import autouse_test_storage, TEST_STORAGE_ROOT -FAKE_CREDENTIALS = PostgresPipelineCredentials("redshift", None, None, None, None) +# FAKE_CREDENTIALS = PostgresPipelineCredentials("redshift", None, None, None, None) -def test_empty_default_schema_name() -> None: - p = Pipeline("test_empty_default_schema_name") - FAKE_CREDENTIALS.DEFAULT_DATASET = environ["DEFAULT_DATASET"] = "test_empty_default_schema_name" + uniq_id() - p.create_pipeline(FAKE_CREDENTIALS, os.path.join(TEST_STORAGE, FAKE_CREDENTIALS.DEFAULT_DATASET), Schema("default")) - p.extract(iter(["a", "b", "c"]), table_name="test") - p.normalize() - p.load() +# def test_empty_default_schema_name() -> None: +# p = Pipeline("test_empty_default_schema_name") +# FAKE_CREDENTIALS.DEFAULT_DATASET = environ["DEFAULT_DATASET"] = "test_empty_default_schema_name" + uniq_id() +# p.create_pipeline(FAKE_CREDENTIALS, os.path.join(TEST_STORAGE_ROOT, FAKE_CREDENTIALS.DEFAULT_DATASET), Schema("default")) +# p.extract(iter(["a", "b", "c"]), table_name="test") +# p.normalize() +# p.load() - # delete data - with 
p.sql_client() as c: - c.drop_dataset() +# # delete data +# with p.sql_client() as c: +# c.drop_dataset() - # try to restore pipeline - r_p = Pipeline("test_empty_default_schema_name") - r_p.restore_pipeline(FAKE_CREDENTIALS, p.working_dir) - schema = r_p.get_default_schema() - assert schema.name == "default" +# # try to restore pipeline +# r_p = Pipeline("test_empty_default_schema_name") +# r_p.restore_pipeline(FAKE_CREDENTIALS, p.working_dir) +# schema = r_p.get_default_schema() +# assert schema.name == "default" -def test_create_wipes_working_dir() -> None: - p = Pipeline("test_create_wipes_working_dir") - FAKE_CREDENTIALS.DEFAULT_DATASET = environ["DEFAULT_DATASET"] = "test_create_wipes_working_dir" + uniq_id() - p.create_pipeline(FAKE_CREDENTIALS, working_dir=os.path.join(TEST_STORAGE, FAKE_CREDENTIALS.DEFAULT_DATASET), schema=Schema("table")) - p.extract(iter(["a", "b", "c"]), table_name="test") - p.normalize() - assert len(p.list_normalized_loads()) > 0 +# def test_create_wipes_working_dir() -> None: +# p = Pipeline("test_create_wipes_working_dir") +# FAKE_CREDENTIALS.DEFAULT_DATASET = environ["DEFAULT_DATASET"] = "test_create_wipes_working_dir" + uniq_id() +# p.create_pipeline(FAKE_CREDENTIALS, working_dir=os.path.join(TEST_STORAGE_ROOT, FAKE_CREDENTIALS.DEFAULT_DATASET), schema=Schema("table")) +# p.extract(iter(["a", "b", "c"]), table_name="test") +# p.normalize() +# assert len(p.list_normalized_loads()) > 0 - # try to restore pipeline - r_p = Pipeline("test_create_wipes_working_dir") - r_p.restore_pipeline(FAKE_CREDENTIALS, p.working_dir) - assert len(r_p.list_normalized_loads()) > 0 - schema = r_p.get_default_schema() - assert schema.name == "table" +# # try to restore pipeline +# r_p = Pipeline("test_create_wipes_working_dir") +# r_p.restore_pipeline(FAKE_CREDENTIALS, p.working_dir) +# assert len(r_p.list_normalized_loads()) > 0 +# schema = r_p.get_default_schema() +# assert schema.name == "table" - # create pipeline in the same dir - p = Pipeline("overwrite_old") - # FAKE_CREDENTIALS.DEFAULT_DATASET = "new" - p.create_pipeline(FAKE_CREDENTIALS, working_dir=os.path.join(TEST_STORAGE, FAKE_CREDENTIALS.DEFAULT_DATASET), schema=Schema("matrix")) - assert len(p.list_normalized_loads()) == 0 +# # create pipeline in the same dir +# p = Pipeline("overwrite_old") +# # FAKE_CREDENTIALS.DEFAULT_DATASET = "new" +# p.create_pipeline(FAKE_CREDENTIALS, working_dir=os.path.join(TEST_STORAGE_ROOT, FAKE_CREDENTIALS.DEFAULT_DATASET), schema=Schema("matrix")) +# assert len(p.list_normalized_loads()) == 0 - # old pipeline is still functional but storage is wiped out - # TODO: but should be inactive - coming in API v2 - # with pytest.raises(InvalidPipelineContextException): - assert len(r_p.list_normalized_loads()) == 0 +# # old pipeline is still functional but storage is wiped out +# # TODO: but should be inactive - coming in API v2 +# # with pytest.raises(InvalidPipelineContextException): +# assert len(r_p.list_normalized_loads()) == 0 - # so recreate it - r_p = Pipeline("overwrite_old") - r_p.restore_pipeline(FAKE_CREDENTIALS, p.working_dir) - assert len(r_p.list_normalized_loads()) == 0 - schema = r_p.get_default_schema() - assert schema.name == "matrix" +# # so recreate it +# r_p = Pipeline("overwrite_old") +# r_p.restore_pipeline(FAKE_CREDENTIALS, p.working_dir) +# assert len(r_p.list_normalized_loads()) == 0 +# schema = r_p.get_default_schema() +# assert schema.name == "matrix" From 533f5767f64c03efc2d7235cd9741f946fb4a5e9 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 20 Sep 2022 
22:28:58 +0200 Subject: [PATCH 06/66] adds pathvalidate to deps --- poetry.lock | 14 +++++++++++++- pyproject.toml | 1 + 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index db4b1ed3ab..2a553a87a7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -892,6 +892,17 @@ python-versions = "*" [package.dependencies] future = "*" +[[package]] +name = "pathvalidate" +version = "2.5.2" +description = "pathvalidate is a Python library to sanitize/validate a string such as filenames/file-paths/etc." +category = "main" +optional = false +python-versions = ">=3.6" + +[package.extras] +test = ["allpairspy", "click", "faker", "pytest (>=6.0.1)", "pytest-discord (>=0.0.6)", "pytest-md-report (>=0.0.12)"] + [[package]] name = "pbr" version = "5.10.0" @@ -1486,7 +1497,7 @@ redshift = ["psycopg2-binary", "psycopg2cffi"] [metadata] lock-version = "1.1" python-versions = "^3.8,<3.11" -content-hash = "d04bbf2afa3c4f46ef5725465da8baad95da271da965408c208c2557b6af198a" +content-hash = "e231693ee02a89e14e8168e0b4d74c284e3f9b0102f211d92efcb62723e19abb" [metadata.files] agate = [ @@ -1927,6 +1938,7 @@ parsedatetime = [ {file = "parsedatetime-2.4-py2-none-any.whl", hash = "sha256:9ee3529454bf35c40a77115f5a596771e59e1aee8c53306f346c461b8e913094"}, {file = "parsedatetime-2.4.tar.gz", hash = "sha256:3d817c58fb9570d1eec1dd46fa9448cd644eeed4fb612684b02dfda3a79cb84b"}, ] +pathvalidate = [] pbr = [] pendulum = [ {file = "pendulum-2.1.2-cp27-cp27m-macosx_10_15_x86_64.whl", hash = "sha256:b6c352f4bd32dff1ea7066bd31ad0f71f8d8100b9ff709fb343f3b86cee43efe"}, diff --git a/pyproject.toml b/pyproject.toml index 83181dc616..9601caebe4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ randomname = "^0.1.5" tzdata = "^2022.1" tomlkit = "^0.11.3" asyncstdlib = "^3.10.5" +pathvalidate = "^2.5.2" [tool.poetry.dev-dependencies] From 3db4695da5dcba231cf22d51bee029f0999de5fe Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 20 Sep 2022 22:32:24 +0200 Subject: [PATCH 07/66] adds additional caps to loader clients, fixes tests to use new writers --- dlt/load/bigquery/client.py | 8 ++- dlt/load/client_base.py | 2 +- dlt/load/dummy/client.py | 8 ++- dlt/load/dummy/configuration.py | 2 +- dlt/load/load.py | 5 +- dlt/load/redshift/client.py | 14 +++--- dlt/load/typing.py | 6 ++- tests/load/bigquery/test_bigquery_client.py | 10 ++-- tests/load/redshift/test_redshift_client.py | 54 ++++++++++++--------- tests/load/test_client.py | 24 ++++----- tests/load/test_dummy_client.py | 8 +-- tests/load/utils.py | 16 +++--- 12 files changed, 89 insertions(+), 68 deletions(-) diff --git a/dlt/load/bigquery/client.py b/dlt/load/bigquery/client.py index 27a73efea4..d1a01e8cec 100644 --- a/dlt/load/bigquery/client.py +++ b/dlt/load/bigquery/client.py @@ -13,7 +13,7 @@ from dlt.common.schema.typing import TTableSchema, TWriteDisposition from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.common.configuration import GcpClientCredentials -from dlt.common.dataset_writers import escape_bigquery_identifier +from dlt.common.data_writers import escape_bigquery_identifier from dlt.common.schema import TColumnSchema, TDataType, Schema, TTableSchemaColumns from dlt.load.typing import LoadJobStatus, DBCursor, TLoaderCapabilities @@ -372,7 +372,11 @@ def capabilities(cls) -> TLoaderCapabilities: "preferred_loader_file_format": "jsonl", "supported_loader_file_formats": ["jsonl"], "max_identifier_length": 1024, - "max_column_length": 300 + "max_column_length": 300, + 
"max_query_length": 1024 * 1024, + "is_max_query_length_in_bytes": False, + "max_text_data_type_length": 10 * 1024 * 1024, + "is_max_text_data_type_length_in_bytes": True } @classmethod diff --git a/dlt/load/client_base.py b/dlt/load/client_base.py index d58671ae27..92a92faa2c 100644 --- a/dlt/load/client_base.py +++ b/dlt/load/client_base.py @@ -208,7 +208,7 @@ def _build_schema_update_sql(self) -> List[str]: def _create_table_update(self, table_name: str, storage_table: TTableSchemaColumns) -> Sequence[TColumnSchema]: # compare table with stored schema and produce delta - updates = self.schema.get_schema_update_for(table_name, storage_table) + updates = self.schema.get_new_columns(table_name, storage_table) logger.info(f"Found {len(updates)} updates for {table_name} in {self.schema.name}") return updates diff --git a/dlt/load/dummy/client.py b/dlt/load/dummy/client.py index 7e356e7a40..76b724001e 100644 --- a/dlt/load/dummy/client.py +++ b/dlt/load/dummy/client.py @@ -1,7 +1,7 @@ import random from types import TracebackType from typing import Dict, Tuple, Type -from dlt.common.dataset_writers import TLoaderFileFormat +from dlt.common.data_writers import TLoaderFileFormat from dlt.common import pendulum from dlt.common.schema import Schema @@ -132,7 +132,11 @@ def capabilities(cls) -> TLoaderCapabilities: "preferred_loader_file_format": cls.CONFIG.LOADER_FILE_FORMAT, "supported_loader_file_formats": [cls.CONFIG.LOADER_FILE_FORMAT], "max_identifier_length": 127, - "max_column_length": 127 + "max_column_length": 127, + "max_query_length": 8 * 1024 * 1024, + "is_max_query_length_in_bytes": True, + "max_text_data_type_length": 65535, + "is_max_text_data_type_length_in_bytes": True } @classmethod diff --git a/dlt/load/dummy/configuration.py b/dlt/load/dummy/configuration.py index 79c414fd50..87eef77817 100644 --- a/dlt/load/dummy/configuration.py +++ b/dlt/load/dummy/configuration.py @@ -2,7 +2,7 @@ from dlt.common.typing import StrAny from dlt.common.configuration import make_configuration -from dlt.common.dataset_writers import TLoaderFileFormat +from dlt.common.data_writers import TLoaderFileFormat from dlt.load.configuration import LoaderClientConfiguration diff --git a/dlt/load/load.py b/dlt/load/load.py index 7b5516c9cf..54ed8fa4f1 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -10,7 +10,7 @@ from dlt.common.exceptions import TerminalValueError from dlt.common.schema import Schema from dlt.common.schema.typing import TTableSchema -from dlt.common.storages.load_storage import LoadStorage +from dlt.common.storages import LoadStorage from dlt.common.telemetry import get_logging_extras, set_gauge_all_labels from dlt.common.typing import StrAny @@ -95,7 +95,7 @@ def w_spool_job(self: "Load", file_path: str, load_id: str, schema: Schema) -> O table = self.get_load_table(schema, job_info.table_name, file_path) if table["write_disposition"] not in ["append", "replace"]: raise LoadClientUnsupportedWriteDisposition(job_info.table_name, table["write_disposition"], file_path) - job = client.start_file_load(table, self.load_storage.storage._make_path(file_path)) + job = client.start_file_load(table, self.load_storage.storage.make_full_path(file_path)) except (LoadClientTerminalException, TerminalValueError): # if job irreversibly cannot be started, mark it as failed logger.exception(f"Terminal problem with spooling job {file_path}") @@ -240,6 +240,7 @@ def run(self, pool: ThreadPool) -> TRunMetrics: self.load_counter.inc() logger.metrics("Load package metrics", 
extra=get_logging_extras([self.load_counter])) else: + # TODO: this loop must be urgently removed. while True: remaining_jobs = self.complete_jobs(load_id, jobs) if len(remaining_jobs) == 0: diff --git a/dlt/load/redshift/client.py b/dlt/load/redshift/client.py index a7dba902d4..764c04ea3b 100644 --- a/dlt/load/redshift/client.py +++ b/dlt/load/redshift/client.py @@ -14,7 +14,7 @@ from dlt.common.typing import StrAny from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.dataset_writers import escape_redshift_identifier +from dlt.common.data_writers import escape_redshift_identifier from dlt.common.schema import COLUMN_HINTS, TColumnSchema, TColumnSchemaBase, TDataType, THintType, Schema, TTableSchemaColumns, add_missing_hints from dlt.common.schema.typing import TTableSchema, TWriteDisposition @@ -139,10 +139,6 @@ def fully_qualified_dataset_name(self) -> str: class RedshiftInsertLoadJob(LoadJob): - - MAX_STATEMENT_SIZE = 8 * 1024 * 1024 - - def __init__(self, table_name: str, write_disposition: TWriteDisposition, file_path: str, sql_client: SqlClientBase["psycopg2.connection"]) -> None: super().__init__(JobClientBase.get_file_name_from_file_path(file_path)) self._sql_client = sql_client @@ -174,7 +170,7 @@ def _insert(self, qualified_table_name: str, write_disposition: TWriteDispositio if write_disposition == "replace": insert_sql.append(SQL("DELETE FROM {};").format(SQL(qualified_table_name))) # is_eof = False - while content := f.read(RedshiftInsertLoadJob.MAX_STATEMENT_SIZE): + while content := f.read(RedshiftClient.capabilities()["max_query_length"] // 2): # read one more line in order to # 1. complete the content which ends at "random" position, not an end line # 2. to modify it's ending without a need to re-allocating the 8MB of "content" @@ -343,7 +339,11 @@ def capabilities(cls) -> TLoaderCapabilities: "preferred_loader_file_format": "insert_values", "supported_loader_file_formats": ["insert_values"], "max_identifier_length": 127, - "max_column_length": 127 + "max_column_length": 127, + "max_query_length": 16 * 1024 * 1024, + "is_max_query_length_in_bytes": True, + "max_text_data_type_length": 65535, + "is_max_text_data_type_length_in_bytes": True } @classmethod diff --git a/dlt/load/typing.py b/dlt/load/typing.py index 62ff129886..b103cc6719 100644 --- a/dlt/load/typing.py +++ b/dlt/load/typing.py @@ -1,6 +1,6 @@ from typing import Any, AnyStr, List, Literal, Optional, Tuple, TypeVar, TypedDict -from dlt.common.dataset_writers import TLoaderFileFormat +from dlt.common.data_writers import TLoaderFileFormat LoadJobStatus = Literal["running", "failed", "retry", "completed"] @@ -13,6 +13,10 @@ class TLoaderCapabilities(TypedDict): supported_loader_file_formats: List[TLoaderFileFormat] max_identifier_length: int max_column_length: int + max_query_length: int + is_max_query_length_in_bytes: bool + max_text_data_type_length: int + is_max_text_data_type_length_in_bytes: bool # type for dbapi cursor diff --git a/tests/load/bigquery/test_bigquery_client.py b/tests/load/bigquery/test_bigquery_client.py index e22ea64107..e70b0bb9dc 100644 --- a/tests/load/bigquery/test_bigquery_client.py +++ b/tests/load/bigquery/test_bigquery_client.py @@ -12,8 +12,8 @@ from dlt.load import Load from dlt.load.bigquery.client import BigQueryClient -from tests.utils import TEST_STORAGE, delete_storage -from tests.load.utils import cm_yield_client_with_storage, expect_load_file, prepare_table, yield_client_with_storage +from tests.utils import TEST_STORAGE_ROOT, 
delete_test_storage +from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage, cm_yield_client_with_storage @pytest.fixture(scope="module") @@ -23,12 +23,12 @@ def client() -> Iterator[BigQueryClient]: @pytest.fixture def file_storage() -> FileStorage: - return FileStorage(TEST_STORAGE, file_type="b", makedirs=True) + return FileStorage(TEST_STORAGE_ROOT, file_type="b", makedirs=True) @pytest.fixture(autouse=True) def auto_delete_storage() -> None: - delete_storage() + delete_test_storage() def test_bigquery_job_errors(client: BigQueryClient, file_storage: FileStorage) -> None: @@ -61,7 +61,7 @@ def test_bigquery_job_errors(client: BigQueryClient, file_storage: FileStorage) job = expect_load_file(client, file_storage, json.dumps(load_json), user_table_name) # start a job from the same file. it should fallback to retrieve job silently - r_job = client.start_file_load(client.schema.get_table(user_table_name), file_storage._make_path(job.file_name())) + r_job = client.start_file_load(client.schema.get_table(user_table_name), file_storage.make_full_path(job.file_name())) assert r_job.status() == "completed" diff --git a/tests/load/redshift/test_redshift_client.py b/tests/load/redshift/test_redshift_client.py index c799bc2669..020bb42a52 100644 --- a/tests/load/redshift/test_redshift_client.py +++ b/tests/load/redshift/test_redshift_client.py @@ -1,5 +1,6 @@ from typing import Iterator import pytest +from unittest.mock import patch from dlt.common import pendulum, Decimal from dlt.common.arithmetics import numeric_default_context @@ -11,18 +12,18 @@ from dlt.load import Load from dlt.load.redshift.client import RedshiftClient, RedshiftInsertLoadJob, psycopg2 -from tests.utils import TEST_STORAGE, delete_storage, skipifpypy +from tests.utils import TEST_STORAGE_ROOT, delete_test_storage, skipifpypy from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage @pytest.fixture def file_storage() -> FileStorage: - return FileStorage(TEST_STORAGE, file_type="b", makedirs=True) + return FileStorage(TEST_STORAGE_ROOT, file_type="b", makedirs=True) @pytest.fixture(autouse=True) def auto_delete_storage() -> None: - delete_storage() + delete_test_storage() @pytest.fixture(scope="module") @@ -94,13 +95,15 @@ def test_long_names(client: RedshiftClient) -> None: @skipifpypy def test_loading_errors(client: RedshiftClient, file_storage: FileStorage) -> None: + caps = client.capabilities() + user_table_name = prepare_table(client) # insert string longer than redshift maximum insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" # try some unicode value - redshift checks the max length based on utf-8 representation, not the number of characters # max_len_str = 'उ' * (65535 // 3) + 1 -> does not fit # max_len_str = 'a' * 65535 + 1 -> does not fit - max_len_str = 'उ' * ((65535 // 3) + 1) + max_len_str = 'उ' * ((caps["max_text_data_type_length"] // 3) + 1) # max_len_str_b = max_len_str.encode("utf-8") # print(len(max_len_str_b)) row_id = uniq_id() @@ -157,10 +160,13 @@ def test_loading_errors(client: RedshiftClient, file_storage: FileStorage) -> No def test_query_split(client: RedshiftClient, file_storage: FileStorage) -> None: - max_statement_size = RedshiftInsertLoadJob.MAX_STATEMENT_SIZE - try: - # this guarantees that we execute inserts line by line - RedshiftInsertLoadJob.MAX_STATEMENT_SIZE = 1 + mocked_caps = RedshiftClient.capabilities() + # this guarantees that we execute inserts line by line + 
mocked_caps["max_query_length"] = 2 + + with patch.object(RedshiftClient, "capabilities") as caps: + caps.return_value = mocked_caps + print(RedshiftClient.capabilities()) user_table_name = prepare_table(client) insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" insert_values = "('{}', '{}', '90238094809sajlkjxoiewjhduuiuehd', '{}')" @@ -182,19 +188,23 @@ def test_query_split(client: RedshiftClient, file_storage: FileStorage) -> None: assert ids == v_ids - finally: - RedshiftInsertLoadJob.MAX_STATEMENT_SIZE = max_statement_size - @pytest.mark.skip -def test_maximum_statement(client: RedshiftClient, file_storage: FileStorage) -> None: - assert RedshiftInsertLoadJob.MAX_STATEMENT_SIZE == 20 * 1024 * 1024, "to enable this test, you must increase RedshiftInsertLoadJob.MAX_STATEMENT_SIZE = 20 * 1024 * 1024" - insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" - insert_values = "('{}', '{}', '90238094809sajlkjxoiewjhduuiuehd', '{}'){}" - insert_sql = insert_sql + insert_values.format(uniq_id(), uniq_id(), str(pendulum.now()), ",\n") * 150000 - insert_sql += insert_values.format(uniq_id(), uniq_id(), str(pendulum.now()), ";") +@skipifpypy +def test_maximum_query_size(client: RedshiftClient, file_storage: FileStorage) -> None: + mocked_caps = RedshiftClient.capabilities() + # this guarantees that we cross the redshift query limit + mocked_caps["max_query_length"] = 2 * 20 * 1024 * 1024 - user_table_name = prepare_table(client) - with pytest.raises(LoadClientTerminalInnerException) as exv: - expect_load_file(client, file_storage, insert_sql, user_table_name) - # psycopg2.errors.SyntaxError: Statement is too large. Statement Size: 20971754 bytes. Maximum Allowed: 16777216 bytes - assert type(exv.value.inner_exc) is psycopg2.ProgrammingError \ No newline at end of file + with patch.object(RedshiftClient, "capabilities") as caps: + caps.return_value = mocked_caps + + insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" + insert_values = "('{}', '{}', '90238094809sajlkjxoiewjhduuiuehd', '{}'){}" + insert_sql = insert_sql + insert_values.format(uniq_id(), uniq_id(), str(pendulum.now()), ",\n") * 150000 + insert_sql += insert_values.format(uniq_id(), uniq_id(), str(pendulum.now()), ";") + + user_table_name = prepare_table(client) + with pytest.raises(LoadClientTerminalInnerException) as exv: + expect_load_file(client, file_storage, insert_sql, user_table_name) + # psycopg2.errors.SyntaxError: Statement is too large. Statement Size: 20971754 bytes. 
Maximum Allowed: 16777216 bytes + assert type(exv.value.inner_exc) is psycopg2.errors.SyntaxError \ No newline at end of file diff --git a/tests/load/test_client.py b/tests/load/test_client.py index 483b348749..e547605027 100644 --- a/tests/load/test_client.py +++ b/tests/load/test_client.py @@ -12,9 +12,9 @@ from dlt.load.client_base import DBCursor, SqlJobClientBase -from tests.utils import TEST_STORAGE, delete_storage +from tests.utils import TEST_STORAGE_ROOT, delete_test_storage from tests.common.utils import load_json_case -from tests.load.utils import TABLE_UPDATE, TABLE_ROW, expect_load_file, yield_client_with_storage, cm_yield_client_with_storage, write_dataset, prepare_table +from tests.load.utils import TABLE_UPDATE, TABLE_UPDATE_COLUMNS_SCHEMA, TABLE_ROW, expect_load_file, yield_client_with_storage, cm_yield_client_with_storage, write_dataset, prepare_table ALL_CLIENTS = ['redshift_client', 'bigquery_client'] @@ -23,12 +23,12 @@ @pytest.fixture def file_storage() -> FileStorage: - return FileStorage(TEST_STORAGE, file_type="b", makedirs=True) + return FileStorage(TEST_STORAGE_ROOT, file_type="b", makedirs=True) @pytest.fixture(autouse=True) def auto_delete_storage() -> None: - delete_storage() + delete_test_storage() @pytest.fixture(scope="module") @@ -211,7 +211,7 @@ def test_data_writer_load(client: SqlJobClientBase, file_storage: FileStorage) - canonical_name = client.sql_client.make_qualified_table_name(table_name) # write only first row with io.StringIO() as f: - write_dataset(client, f, [rows[0]], rows[0].keys()) + write_dataset(client, f, [rows[0]], client.schema.get_table(table_name)["columns"]) query = f.getvalue() expect_load_file(client, file_storage, query, table_name) db_row = client.sql_client.execute_sql(f"SELECT * FROM {canonical_name}")[0] @@ -219,7 +219,7 @@ def test_data_writer_load(client: SqlJobClientBase, file_storage: FileStorage) - assert list(db_row) == list(rows[0].values()) # write second row that contains two nulls with io.StringIO() as f: - write_dataset(client, f, [rows[1]], rows[0].keys()) + write_dataset(client, f, [rows[1]], client.schema.get_table(table_name)["columns"]) query = f.getvalue() expect_load_file(client, file_storage, query, table_name) db_row = client.sql_client.execute_sql(f"SELECT * FROM {canonical_name} WHERE f_int = {rows[1]['f_int']}")[0] @@ -236,7 +236,7 @@ def test_data_writer_string_escape(client: SqlJobClientBase, file_storage: FileS inj_str = f", NULL'); DROP TABLE {canonical_name} --" row["f_str"] = inj_str with io.StringIO() as f: - write_dataset(client, f, [rows[0]], rows[0].keys()) + write_dataset(client, f, [rows[0]], client.schema.get_table(table_name)["columns"]) query = f.getvalue() expect_load_file(client, file_storage, query, table_name) db_row = client.sql_client.execute_sql(f"SELECT * FROM {canonical_name}")[0] @@ -248,7 +248,7 @@ def test_data_writer_string_escape_edge(client: SqlJobClientBase, file_storage: rows, table_name = prepare_schema(client, "weird_rows") canonical_name = client.sql_client.make_qualified_table_name(table_name) with io.StringIO() as f: - write_dataset(client, f, rows, rows[0].keys()) + write_dataset(client, f, rows, client.schema.get_table(table_name)["columns"]) query = f.getvalue() expect_load_file(client, file_storage, query, table_name) for i in range(1,len(rows) + 1): @@ -267,7 +267,7 @@ def test_load_with_all_types(client: SqlJobClientBase, write_disposition: str, f canonical_name = client.sql_client.make_qualified_table_name(table_name) # write row with io.StringIO() as f: - 
write_dataset(client, f, [TABLE_ROW], TABLE_ROW.keys()) + write_dataset(client, f, [TABLE_ROW], TABLE_UPDATE_COLUMNS_SCHEMA) query = f.getvalue() expect_load_file(client, file_storage, query, table_name) db_row = list(client.sql_client.execute_sql(f"SELECT * FROM {canonical_name}")[0]) @@ -299,7 +299,7 @@ def test_write_dispositions(client: SqlJobClientBase, write_disposition: str, fi table_row = deepcopy(TABLE_ROW) table_row["col1"] = idx with io.StringIO() as f: - write_dataset(client, f, [table_row], TABLE_ROW.keys()) + write_dataset(client, f, [table_row], TABLE_UPDATE_COLUMNS_SCHEMA) query = f.getvalue() expect_load_file(client, file_storage, query, t) db_rows = list(client.sql_client.execute_sql(f"SELECT * FROM {t} ORDER BY col1 ASC")) @@ -323,12 +323,12 @@ def test_retrieve_job(client: SqlJobClientBase, file_storage: FileStorage) -> No "timestamp": str(pendulum.now()) } with io.StringIO() as f: - write_dataset(client, f, [load_json], load_json.keys()) + write_dataset(client, f, [load_json], client.schema.get_table(user_table_name)["columns"]) dataset = f.getvalue() job = expect_load_file(client, file_storage, dataset, user_table_name) # now try to retrieve the job # TODO: we should re-create client instance as this call is intended to be run after some disruption ie. stopped loader process - r_job = client.restore_file_load(file_storage._make_path(job.file_name())) + r_job = client.restore_file_load(file_storage.make_full_path(job.file_name())) assert r_job.status() == "completed" # use just file name to restore r_job = client.restore_file_load(job.file_name()) diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index 0330bb71bf..e1be8b6243 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -20,7 +20,7 @@ from dlt.load import Load, __version__ from dlt.load.dummy.configuration import DummyClientConfiguration -from tests.utils import clean_storage, init_logger +from tests.utils import clean_test_storage, init_logger NORMALIZED_FILES = [ @@ -31,7 +31,7 @@ @pytest.fixture(autouse=True) def storage() -> FileStorage: - clean_storage(init_normalize=True, init_loader=True) + clean_test_storage(init_normalize=True, init_loader=True) @pytest.fixture(scope="module", autouse=True) @@ -316,10 +316,10 @@ def prepare_load_package(load_storage: LoadStorage, cases: Sequence[str]) -> Tup load_storage.create_temp_load_package(load_id) for case in cases: path = f"./tests/load/cases/loading/{case}" - shutil.copy(path, load_storage.storage._make_path(f"{load_id}/{LoadStorage.NEW_JOBS_FOLDER}")) + shutil.copy(path, load_storage.storage.make_full_path(f"{load_id}/{LoadStorage.NEW_JOBS_FOLDER}")) for f in ["schema_updates.json", "schema.json"]: path = f"./tests/load/cases/loading/{f}" - shutil.copy(path, load_storage.storage._make_path(load_id)) + shutil.copy(path, load_storage.storage.make_full_path(load_id)) load_storage.commit_temp_load_package(load_id) schema = load_storage.load_package_schema(load_id) return load_id, schema diff --git a/tests/load/utils.py b/tests/load/utils.py index 9deed5a806..8ad489e51c 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -5,7 +5,7 @@ from dlt.common import json, Decimal from dlt.common.configuration import make_configuration from dlt.common.configuration.schema_volume_configuration import SchemaVolumeConfiguration -from dlt.common.dataset_writers import write_insert_values, write_jsonl +from dlt.common.data_writers import DataWriter from dlt.common.file_storage import FileStorage from 
dlt.common.schema import TColumnSchema, TTableSchemaColumns from dlt.common.storages.schema_storage import SchemaStorage @@ -64,6 +64,7 @@ "nullable": False }, ] +TABLE_UPDATE_COLUMNS_SCHEMA: TTableSchemaColumns = {t["name"]:t for t in TABLE_UPDATE} TABLE_ROW = { "col1": 989127831, @@ -86,7 +87,7 @@ def expect_load_file(client: JobClientBase, file_storage: FileStorage, query: st file_name = uniq_id() file_storage.save(file_name, query.encode("utf-8")) table = Load.get_load_table(client.schema, table_name, file_name) - job = client.start_file_load(table, file_storage._make_path(file_name)) + job = client.start_file_load(table, file_storage.make_full_path(file_name)) while job.status() == "running": sleep(0.5) assert job.file_name() == file_name @@ -131,10 +132,7 @@ def cm_yield_client_with_storage(client_type: str, initial_values: StrAny = None return yield_client_with_storage(client_type, initial_values) -def write_dataset(client: JobClientBase, f: IO[Any], rows: Sequence[StrAny], headers: Iterable[str]) -> None: - if client.capabilities()["preferred_loader_file_format"] == "jsonl": - write_jsonl(f, rows) - elif client.capabilities()["preferred_loader_file_format"] == "insert_values": - write_insert_values(f, rows, headers) - else: - raise ValueError(client.capabilities()["preferred_loader_file_format"]) +def write_dataset(client: JobClientBase, f: IO[Any], rows: Sequence[StrAny], columns_schema: TTableSchemaColumns) -> None: + file_format = client.capabilities()["preferred_loader_file_format"] + writer = DataWriter.from_file_format(file_format, f) + writer.write_all(columns_schema, rows) From be151725e96da930fa347ea4df16ace5aaada1bc Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 20 Sep 2022 22:35:12 +0200 Subject: [PATCH 08/66] refactoring extractor and normalizer code to introduce new data writers and extracted file jsonl format --- dlt/common/storages/__init__.py | 3 + dlt/common/storages/exceptions.py | 2 +- dlt/common/storages/load_storage.py | 27 +++---- dlt/common/storages/normalize_storage.py | 72 ++++++++++--------- dlt/common/storages/schema_storage.py | 1 - dlt/extract/extractor_storage.py | 26 ++++--- dlt/normalize/configuration.py | 2 +- dlt/normalize/normalize.py | 7 +- tests/common/storages/test_loader_storage.py | 4 +- .../common/storages/test_normalize_storage.py | 20 ++---- tests/common/storages/test_schema_storage.py | 8 +-- .../common/storages/test_versioned_storage.py | 36 +++++----- tests/common/test_data_writers.py | 48 +++++++------ tests/normalize/test_normalize.py | 16 ++--- 14 files changed, 136 insertions(+), 136 deletions(-) diff --git a/dlt/common/storages/__init__.py b/dlt/common/storages/__init__.py index e82ac29eb2..8cf7b71216 100644 --- a/dlt/common/storages/__init__.py +++ b/dlt/common/storages/__init__.py @@ -1,2 +1,5 @@ from .schema_storage import SchemaStorage # noqa: F401 from .live_schema_storage import LiveSchemaStorage # noqa: F401 +from .normalize_storage import NormalizeStorage # noqa: F401 +from .versioned_storage import VersionedStorage # noqa: F401 +from .load_storage import LoadStorage # noqa: F401 diff --git a/dlt/common/storages/exceptions.py b/dlt/common/storages/exceptions.py index 162fc76a93..55b55b0711 100644 --- a/dlt/common/storages/exceptions.py +++ b/dlt/common/storages/exceptions.py @@ -2,7 +2,7 @@ from typing import Iterable from dlt.common.exceptions import DltException -from dlt.common.dataset_writers import TLoaderFileFormat +from dlt.common.data_writers import TLoaderFileFormat class StorageException(DltException): 
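Editor's note: the hunks above and below replace the old `write_insert_values` / `write_jsonl` helpers with the pluggable `DataWriter` API — callers obtain a writer via `DataWriter.from_file_format()` and hand it a column schema plus rows through `write_all()`. A minimal sketch of that calling pattern, assuming the `dlt` package at this revision and a column schema shaped like the `row_to_column_schemas` test helper added in this patch (the exact set of required column fields is an assumption, not confirmed by the diff):

    import io

    from dlt.common.data_writers import DataWriter

    # column schema keyed by column name, as the tests in this patch build it
    columns = {
        "col1": {"name": "col1", "data_type": "text", "nullable": False},
        "col2": {"name": "col2", "data_type": "text", "nullable": True},
    }
    rows = [{"col1": "a", "col2": "b"}, {"col1": "c", "col2": "d"}]

    with io.StringIO() as buf:
        # "insert_values" is the loader file format the Redshift client prefers in this patch
        writer = DataWriter.from_file_format("insert_values", buf)
        # write_all is expected to emit a complete file for the chosen format in one call
        writer.write_all(columns, rows)
        print(buf.getvalue())  # per the tests, the output starts with "INSERT INTO {}(...)"

The same pattern is what `LoadStorage.write_temp_job_file()` switches to in the `load_storage.py` diff that follows, so the storage layer no longer branches on the file format by hand.
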
diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index ee2b74b481..456ae9d538 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -5,7 +5,7 @@ from dlt.common import json, pendulum from dlt.common.file_storage import FileStorage -from dlt.common.dataset_writers import TLoaderFileFormat, write_jsonl, write_insert_values +from dlt.common.data_writers import TLoaderFileFormat, DataWriter from dlt.common.configuration import LoadVolumeConfiguration from dlt.common.exceptions import TerminalValueError from dlt.common.schema import Schema, TSchemaUpdate, TTableSchemaColumns @@ -77,10 +77,12 @@ def create_temp_load_package(self, load_id: str) -> None: def write_temp_job_file(self, load_id: str, table_name: str, table: TTableSchemaColumns, file_id: str, rows: Sequence[StrAny]) -> str: file_name = self.build_job_file_name(table_name, file_id) with self.storage.open_file(join(load_id, LoadStorage.NEW_JOBS_FOLDER, file_name), mode="w") as f: - if self.preferred_file_format == "jsonl": - write_jsonl(f, rows) - elif self.preferred_file_format == "insert_values": - write_insert_values(f, rows, table.keys()) + writer = DataWriter.from_file_format(self.preferred_file_format, f) + writer.write_all(table, rows) + # if self.preferred_file_format == "jsonl": + # write_jsonl(f, rows) + # elif self.preferred_file_format == "insert_values": + # write_insert_values(f, rows, table.keys()) return Path(file_name).name def load_package_schema(self, load_id: str) -> Schema: @@ -188,13 +190,6 @@ def get_package_path(self, load_id: str) -> str: def get_completed_package_path(self, load_id: str) -> str: return join(LoadStorage.LOADED_FOLDER, load_id) - def build_job_file_name(self, table_name: str, file_id: str, retry_count: int = 0) -> str: - if "." in table_name: - raise ValueError(table_name) - if "." 
in file_id: - raise ValueError(file_id) - return f"{table_name}.{file_id}.{int(retry_count)}.{self.preferred_file_format}" - def job_elapsed_time_seconds(self, file_path: str) -> float: return pendulum.now().timestamp() - os.path.getmtime(file_path) # type: ignore @@ -211,7 +206,7 @@ def _move_job(self, load_id: str, source_folder: TWorkingFolder, dest_folder: TW load_path = self.get_package_path(load_id) dest_path = join(load_path, dest_folder, new_file_name or file_name) self.storage.atomic_rename(join(load_path, source_folder, file_name), dest_path) - return self.storage._make_path(dest_path) + return self.storage.make_full_path(dest_path) def _get_job_folder_path(self, load_id: str, folder: TWorkingFolder) -> str: return join(self.get_package_path(load_id), folder) @@ -219,6 +214,12 @@ def _get_job_folder_path(self, load_id: str, folder: TWorkingFolder) -> str: def _get_job_file_path(self, load_id: str, folder: TWorkingFolder, file_name: str) -> str: return join(self._get_job_folder_path(load_id, folder), file_name) + def build_job_file_name(self, table_name: str, file_id: str, retry_count: int = 0, validate_components: bool = True) -> str: + if validate_components: + FileStorage.validate_file_name_component(table_name) + FileStorage.validate_file_name_component(file_id) + return f"{table_name}.{file_id}.{int(retry_count)}.{self.preferred_file_format}" + @staticmethod def parse_job_file_name(file_name: str) -> TParsedJobFileName: p = Path(file_name) diff --git a/dlt/common/storages/normalize_storage.py b/dlt/common/storages/normalize_storage.py index fd52ae72e5..3cdbf4a1c5 100644 --- a/dlt/common/storages/normalize_storage.py +++ b/dlt/common/storages/normalize_storage.py @@ -1,4 +1,4 @@ -from typing import List, Sequence, Tuple, Type +from typing import List, Sequence, Tuple, Type, NamedTuple from itertools import groupby from pathlib import Path @@ -8,12 +8,16 @@ from dlt.common.storages.versioned_storage import VersionedStorage +class TParsedNormalizeFileName(NamedTuple): + schema_name: str + table_name: str + file_id: str + + class NormalizeStorage(VersionedStorage): STORAGE_VERSION = "1.0.0" EXTRACTED_FOLDER: str = "extracted" # folder within the volume where extracted files to be normalized are stored - EXTRACTED_FILE_EXTENSION = ".extracted.json" - EXTRACTED_FILE_EXTENSION_LEN = len(EXTRACTED_FILE_EXTENSION) def __init__(self, is_owner: bool, C: Type[NormalizeVolumeConfiguration]) -> None: super().__init__(NormalizeStorage.STORAGE_VERSION, is_owner, FileStorage(C.NORMALIZE_VOLUME_PATH, "t", makedirs=is_owner)) @@ -31,45 +35,45 @@ def get_grouped_iterator(self, files: Sequence[str]) -> "groupby[str, str]": @staticmethod def chunk_by_events(files: Sequence[str], max_events: int, processing_cores: int) -> List[Sequence[str]]: - # should distribute ~ N events evenly among m cores with fallback for small amounts of events - - def count_events(file_name : str) -> int: - # return event count from file name - return NormalizeStorage.get_events_count(file_name) - - counts = list(map(count_events, files)) - # make a list of files containing ~max_events - events_count = 0 - m = 0 - while events_count < max_events and m < len(files): - events_count += counts[m] - m += 1 - processing_chunks = round(m / processing_cores) - if processing_chunks == 0: - # return one small chunk - return [files] - else: - # should return ~ amount of chunks to fill all the cores - return list(chunks(files[:m], processing_chunks)) - - @staticmethod - def get_events_count(file_name: str) -> int: - return 
NormalizeStorage._parse_extracted_file_name(file_name)[0] + return [files] + + # # should distribute ~ N events evenly among m cores with fallback for small amounts of events + + # def count_events(file_name : str) -> int: + # # return event count from file name + # return NormalizeStorage.get_events_count(file_name) + + # counts = list(map(count_events, files)) + # # make a list of files containing ~max_events + # events_count = 0 + # m = 0 + # while events_count < max_events and m < len(files): + # events_count += counts[m] + # m += 1 + # processing_chunks = round(m / processing_cores) + # if processing_chunks == 0: + # # return one small chunk + # return [files] + # else: + # # should return ~ amount of chunks to fill all the cores + # return list(chunks(files[:m], processing_chunks)) @staticmethod def get_schema_name(file_name: str) -> str: - return NormalizeStorage._parse_extracted_file_name(file_name)[2] + return NormalizeStorage.parse_normalize_file_name(file_name).schema_name @staticmethod - def build_extracted_file_name(schema_name: str, stem: str, event_count: int, load_id: str) -> str: + def build_extracted_file_stem(schema_name: str, table_name: str, file_id: str) -> str: # builds file name with the extracted data to be passed to normalize - return f"{schema_name}_{stem}_{load_id}_{event_count}{NormalizeStorage.EXTRACTED_FILE_EXTENSION}" + return f"{schema_name}.{table_name}.{file_id}" @staticmethod - def _parse_extracted_file_name(file_name: str) -> Tuple[int, str, str]: + def parse_normalize_file_name(file_name: str) -> TParsedNormalizeFileName: # parse extracted file name and returns (events found, load id, schema_name) - if not file_name.endswith(NormalizeStorage.EXTRACTED_FILE_EXTENSION): + if not file_name.endswith("jsonl"): raise ValueError(file_name) - parts = Path(file_name[:-NormalizeStorage.EXTRACTED_FILE_EXTENSION_LEN]).stem.split("_") - return (int(parts[-1]), parts[-2], parts[0]) \ No newline at end of file + parts = Path(file_name).stem.split(".") + if len(parts) != 3: + raise ValueError(file_name) + return TParsedNormalizeFileName(*parts) diff --git a/dlt/common/storages/schema_storage.py b/dlt/common/storages/schema_storage.py index 9fd1faa65f..6ef8beeb32 100644 --- a/dlt/common/storages/schema_storage.py +++ b/dlt/common/storages/schema_storage.py @@ -7,7 +7,6 @@ from dlt.common.configuration.schema_volume_configuration import TSchemaFileFormat from dlt.common.file_storage import FileStorage from dlt.common.schema import Schema, verify_schema_hash -from dlt.common.schema.typing import TStoredSchema from dlt.common.typing import DictStrAny from dlt.common.configuration import SchemaVolumeConfiguration diff --git a/dlt/extract/extractor_storage.py b/dlt/extract/extractor_storage.py index 32e71f6fec..c116b2fb08 100644 --- a/dlt/extract/extractor_storage.py +++ b/dlt/extract/extractor_storage.py @@ -3,10 +3,8 @@ from dlt.common.json import json_typed_dumps from dlt.common.typing import Any from dlt.common.utils import uniq_id -from dlt.common.schema import normalize_schema_name from dlt.common.file_storage import FileStorage -from dlt.common.storages.versioned_storage import VersionedStorage -from dlt.common.storages.normalize_storage import NormalizeStorage +from dlt.common.storages import VersionedStorage, NormalizeStorage class ExtractorStorageBase(VersionedStorage): @@ -24,18 +22,18 @@ def save_json(self, name: str, d: Any) -> None: self.storage.save(name, json_typed_dumps(d)) def commit_events(self, schema_name: str, processed_file_path: str, dest_file_stem: 
str, no_processed_events: int, load_id: str, with_delete: bool = True) -> str: + raise NotImplementedError() # schema name cannot contain underscores - if schema_name != normalize_schema_name(schema_name): - raise ValueError(schema_name) + # FileStorage.validate_file_name_component(schema_name) - dest_name = NormalizeStorage.build_extracted_file_name(schema_name, dest_file_stem, no_processed_events, load_id) - # if no events extracted from tracker, file is not saved - if no_processed_events > 0: - # moves file to possibly external storage and place in the dest folder atomically - self.storage.copy_cross_storage_atomically( - self.normalize_storage.storage.storage_path, NormalizeStorage.EXTRACTED_FOLDER, processed_file_path, dest_name) + # dest_name = NormalizeStorage.build_extracted_file_stem(schema_name, dest_file_stem, no_processed_events, load_id) + # # if no events extracted from tracker, file is not saved + # if no_processed_events > 0: + # # moves file to possibly external storage and place in the dest folder atomically + # self.storage.copy_cross_storage_atomically( + # self.normalize_storage.storage.storage_path, NormalizeStorage.EXTRACTED_FOLDER, processed_file_path, dest_name) - if with_delete: - self.storage.delete(processed_file_path) + # if with_delete: + # self.storage.delete(processed_file_path) - return dest_name + # return dest_name diff --git a/dlt/normalize/configuration.py b/dlt/normalize/configuration.py index f5090a1f28..b716f5f613 100644 --- a/dlt/normalize/configuration.py +++ b/dlt/normalize/configuration.py @@ -1,7 +1,7 @@ from typing import Type from dlt.common.typing import StrAny -from dlt.common.dataset_writers import TLoaderFileFormat +from dlt.common.data_writers import TLoaderFileFormat from dlt.common.configuration import (PoolRunnerConfiguration, NormalizeVolumeConfiguration, LoadVolumeConfiguration, SchemaVolumeConfiguration, ProductionLoadVolumeConfiguration, ProductionNormalizeVolumeConfiguration, diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index f387229d05..3b4993bcf9 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -8,15 +8,13 @@ from dlt.cli import TRunnerArgs from dlt.common.runners import TRunMetrics, Runnable, run_pool, initialize_runner, workermethod from dlt.common.storages.exceptions import SchemaNotFoundError -from dlt.common.storages.normalize_storage import NormalizeStorage +from dlt.common.storages import NormalizeStorage, SchemaStorage, LoadStorage from dlt.common.telemetry import get_logging_extras from dlt.common.utils import uniq_id from dlt.common.typing import TDataItem from dlt.common.exceptions import PoolException -from dlt.common.storages import SchemaStorage from dlt.common.schema import TSchemaUpdate, Schema from dlt.common.schema.exceptions import CannotCoerceColumnException -from dlt.common.storages.load_storage import LoadStorage from dlt.normalize.configuration import configuration, NormalizeConfiguration @@ -177,7 +175,8 @@ def spool_files(self, schema_name: str, load_id: str, map_f: TMapFuncType, files total_events = 0 for event_file in chain.from_iterable(chunk_files): # flatten chunks self.normalize_storage.storage.delete(event_file) - total_events += NormalizeStorage.get_events_count(event_file) + # TODO: get total events from worker function and make stats per table + # total_events += .... 
# log and update metrics logger.info(f"Chunk {load_id} processed") self.load_package_counter.labels(schema_name).inc() diff --git a/tests/common/storages/test_loader_storage.py b/tests/common/storages/test_loader_storage.py index a829f6207f..68c75e64e8 100644 --- a/tests/common/storages/test_loader_storage.py +++ b/tests/common/storages/test_loader_storage.py @@ -11,7 +11,7 @@ from dlt.common.typing import StrAny from dlt.common.utils import uniq_id -from tests.utils import TEST_STORAGE, write_version, autouse_root_storage +from tests.utils import TEST_STORAGE_ROOT, write_version, autouse_test_storage @pytest.fixture @@ -72,7 +72,7 @@ def test_save_load_schema(storage: LoadStorage) -> None: def test_job_elapsed_time_seconds(storage: LoadStorage) -> None: load_id, fn = start_loading_file(storage, "test file") - fp = storage.storage._make_path(storage._get_job_file_path(load_id, "started_jobs", fn)) + fp = storage.storage.make_full_path(storage._get_job_file_path(load_id, "started_jobs", fn)) elapsed = storage.job_elapsed_time_seconds(fp) sleep(0.3) # do not touch file diff --git a/tests/common/storages/test_normalize_storage.py b/tests/common/storages/test_normalize_storage.py index e1eb8552bd..88e0a28c3f 100644 --- a/tests/common/storages/test_normalize_storage.py +++ b/tests/common/storages/test_normalize_storage.py @@ -1,11 +1,12 @@ import pytest from dlt.common.storages.exceptions import NoMigrationPathException -from dlt.common.storages.normalize_storage import NormalizeStorage +from dlt.common.storages import NormalizeStorage from dlt.common.configuration import NormalizeVolumeConfiguration +from dlt.common.storages.normalize_storage import TParsedNormalizeFileName from dlt.common.utils import uniq_id -from tests.utils import write_version, autouse_root_storage +from tests.utils import write_version, autouse_test_storage @pytest.mark.skip() def test_load_events_and_group_by_sender() -> None: @@ -13,22 +14,15 @@ def test_load_events_and_group_by_sender() -> None: pass -@pytest.mark.skip() -def test_chunk_by_events() -> None: - # TODO: should distribute ~ N events evenly among m cores with fallback for small amounts of events - pass - - def test_build_extracted_file_name() -> None: load_id = uniq_id() - name = NormalizeStorage.build_extracted_file_name("event", "table", 121, load_id) + name = NormalizeStorage.build_extracted_file_stem("event", "table_with_parts__many", load_id) + ".jsonl" assert NormalizeStorage.get_schema_name(name) == "event" - assert NormalizeStorage.get_events_count(name) == 121 - assert NormalizeStorage._parse_extracted_file_name(name) == (121, load_id, "event") + assert NormalizeStorage.parse_normalize_file_name(name) == TParsedNormalizeFileName("event", "table_with_parts__many", load_id) # empty schema should be supported - name = NormalizeStorage.build_extracted_file_name("", "table", 121, load_id) - assert NormalizeStorage._parse_extracted_file_name(name) == (121, load_id, "") + name = NormalizeStorage.build_extracted_file_stem("", "table", load_id) + ".jsonl" + assert NormalizeStorage.parse_normalize_file_name(name) == TParsedNormalizeFileName("", "table", load_id) def test_full_migration_path() -> None: diff --git a/tests/common/storages/test_schema_storage.py b/tests/common/storages/test_schema_storage.py index 28755c6acb..fe2d7d5d05 100644 --- a/tests/common/storages/test_schema_storage.py +++ b/tests/common/storages/test_schema_storage.py @@ -14,7 +14,7 @@ from dlt.common.storages import SchemaStorage, LiveSchemaStorage from dlt.common.typing import 
DictStrAny -from tests.utils import autouse_root_storage, TEST_STORAGE +from tests.utils import autouse_test_storage, TEST_STORAGE_ROOT from tests.common.utils import load_yml_case, yml_case_path @@ -26,13 +26,13 @@ def storage() -> SchemaStorage: @pytest.fixture def synced_storage() -> SchemaStorage: # will be created in /schemas - return init_storage({"IMPORT_SCHEMA_PATH": TEST_STORAGE + "/import", "EXPORT_SCHEMA_PATH": TEST_STORAGE + "/import"}) + return init_storage({"IMPORT_SCHEMA_PATH": TEST_STORAGE_ROOT + "/import", "EXPORT_SCHEMA_PATH": TEST_STORAGE_ROOT + "/import"}) @pytest.fixture def ie_storage() -> SchemaStorage: # will be created in /schemas - return init_storage({"IMPORT_SCHEMA_PATH": TEST_STORAGE + "/import", "EXPORT_SCHEMA_PATH": TEST_STORAGE + "/export"}) + return init_storage({"IMPORT_SCHEMA_PATH": TEST_STORAGE_ROOT + "/import", "EXPORT_SCHEMA_PATH": TEST_STORAGE_ROOT + "/export"}) def init_storage(initial: DictStrAny = None) -> SchemaStorage: @@ -242,7 +242,7 @@ def test_save_store_schema(storage: SchemaStorage) -> None: def prepare_import_folder(storage: SchemaStorage) -> None: - shutil.copy(yml_case_path("schemas/eth/ethereum_schema_v4"), storage.storage._make_path("../import/ethereum_schema.yaml")) + shutil.copy(yml_case_path("schemas/eth/ethereum_schema_v4"), storage.storage.make_full_path("../import/ethereum_schema.yaml")) def assert_schema_imported(synced_storage: SchemaStorage, storage: SchemaStorage) -> Schema: diff --git a/tests/common/storages/test_versioned_storage.py b/tests/common/storages/test_versioned_storage.py index e4bcbf7a37..c7c2236cc5 100644 --- a/tests/common/storages/test_versioned_storage.py +++ b/tests/common/storages/test_versioned_storage.py @@ -5,7 +5,7 @@ from dlt.common.storages.exceptions import NoMigrationPathException, WrongStorageVersionException from dlt.common.storages.versioned_storage import VersionedStorage -from tests.utils import write_version, root_storage +from tests.utils import write_version, test_storage class MigratedStorage(VersionedStorage): @@ -19,41 +19,41 @@ def migrate_storage(self, from_version: semver.VersionInfo, to_version: semver.V self._save_version(from_version) -def test_new_versioned_storage(root_storage: FileStorage) -> None: - v = VersionedStorage("1.0.1", True, root_storage) +def test_new_versioned_storage(test_storage: FileStorage) -> None: + v = VersionedStorage("1.0.1", True, test_storage) assert v.version == "1.0.1" -def test_new_versioned_storage_non_owner(root_storage: FileStorage) -> None: +def test_new_versioned_storage_non_owner(test_storage: FileStorage) -> None: with pytest.raises(WrongStorageVersionException) as wsve: - VersionedStorage("1.0.1", False, root_storage) - assert wsve.value.storage_path == root_storage.storage_path + VersionedStorage("1.0.1", False, test_storage) + assert wsve.value.storage_path == test_storage.storage_path assert wsve.value.target_version == "1.0.1" assert wsve.value.initial_version == "0.0.0" -def test_migration(root_storage: FileStorage) -> None: - write_version(root_storage, "1.0.0") - v = MigratedStorage("1.2.0", True, root_storage) +def test_migration(test_storage: FileStorage) -> None: + write_version(test_storage, "1.0.0") + v = MigratedStorage("1.2.0", True, test_storage) assert v.version == "1.2.0" -def test_unknown_migration_path(root_storage: FileStorage) -> None: - write_version(root_storage, "1.0.0") +def test_unknown_migration_path(test_storage: FileStorage) -> None: + write_version(test_storage, "1.0.0") with pytest.raises(NoMigrationPathException) 
as wmpe: - MigratedStorage("1.3.0", True, root_storage) + MigratedStorage("1.3.0", True, test_storage) assert wmpe.value.migrated_version == "1.2.0" -def test_only_owner_migrates(root_storage: FileStorage) -> None: - write_version(root_storage, "1.0.0") +def test_only_owner_migrates(test_storage: FileStorage) -> None: + write_version(test_storage, "1.0.0") with pytest.raises(WrongStorageVersionException) as wmpe: - MigratedStorage("1.2.0", False, root_storage) + MigratedStorage("1.2.0", False, test_storage) assert wmpe.value.initial_version == "1.0.0" -def test_downgrade_not_possible(root_storage: FileStorage) -> None: - write_version(root_storage, "1.2.0") +def test_downgrade_not_possible(test_storage: FileStorage) -> None: + write_version(test_storage, "1.2.0") with pytest.raises(NoMigrationPathException) as wmpe: - MigratedStorage("1.1.0", True, root_storage) + MigratedStorage("1.1.0", True, test_storage) assert wmpe.value.migrated_version == "1.2.0" \ No newline at end of file diff --git a/tests/common/test_data_writers.py b/tests/common/test_data_writers.py index 439bae0748..456643f619 100644 --- a/tests/common/test_data_writers.py +++ b/tests/common/test_data_writers.py @@ -1,51 +1,55 @@ import io +import pytest +from typing import Iterator from dlt.common import pendulum -from dlt.common.dataset_writers import write_insert_values, escape_redshift_literal, escape_redshift_identifier, escape_bigquery_identifier +from dlt.common.data_writers import escape_redshift_literal, escape_redshift_identifier, escape_bigquery_identifier +from dlt.common.data_writers.writers import DataWriter, InsertValuesWriter -from tests.common.utils import load_json_case +from tests.common.utils import load_json_case, row_to_column_schemas -def test_simple_insert_writer() -> None: - rows = load_json_case("simple_row") +@pytest.fixture +def insert_writer() -> Iterator[DataWriter]: with io.StringIO() as f: - write_insert_values(f, rows, rows[0].keys()) - lines = f.getvalue().split("\n") + yield InsertValuesWriter(f) + + +def test_simple_insert_writer(insert_writer: DataWriter) -> None: + rows = load_json_case("simple_row") + insert_writer.write_all(row_to_column_schemas(rows[0]), rows) + lines = insert_writer._f.getvalue().split("\n") assert lines[0].startswith("INSERT INTO {}") assert '","'.join(rows[0].keys()) in lines[0] assert lines[1] == "VALUES" assert len(lines) == 4 -def test_bytes_insert_writer() -> None: +def test_bytes_insert_writer(insert_writer: DataWriter) -> None: rows = [{"bytes": b"bytes"}] - with io.StringIO() as f: - write_insert_values(f, rows, rows[0].keys()) - lines = f.getvalue().split("\n") + insert_writer.write_all(row_to_column_schemas(rows[0]), rows) + lines = insert_writer._f.getvalue().split("\n") assert lines[2] == "(from_hex('6279746573'));" -def test_datetime_insert_writer() -> None: +def test_datetime_insert_writer(insert_writer: DataWriter) -> None: rows = [{"datetime": pendulum.from_timestamp(1658928602.575267)}] - with io.StringIO() as f: - write_insert_values(f, rows, rows[0].keys()) - lines = f.getvalue().split("\n") + insert_writer.write_all(row_to_column_schemas(rows[0]), rows) + lines = insert_writer._f.getvalue().split("\n") assert lines[2] == "('2022-07-27T13:30:02.575267+00:00');" -def test_date_insert_writer() -> None: +def test_date_insert_writer(insert_writer: DataWriter) -> None: rows = [{"date": pendulum.date(1974, 8, 11)}] - with io.StringIO() as f: - write_insert_values(f, rows, rows[0].keys()) - lines = f.getvalue().split("\n") + 
insert_writer.write_all(row_to_column_schemas(rows[0]), rows) + lines = insert_writer._f.getvalue().split("\n") assert lines[2] == "('1974-08-11');" -def test_unicode_insert_writer() -> None: +def test_unicode_insert_writer(insert_writer: DataWriter) -> None: rows = load_json_case("weird_rows") - with io.StringIO() as f: - write_insert_values(f, rows, rows[0].keys()) - lines = f.getvalue().split("\n") + insert_writer.write_all(row_to_column_schemas(rows[0]), rows) + lines = insert_writer._f.getvalue().split("\n") assert lines[2].endswith("', NULL''); DROP SCHEMA Public --'),") assert lines[3].endswith("'イロハニホヘト チリヌルヲ ''ワカヨタレソ ツネナラム'),") assert lines[4].endswith("'ऄअआइ''ईउऊऋऌऍऎए'),") diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py index 703baf628f..5edb12244e 100644 --- a/tests/normalize/test_normalize.py +++ b/tests/normalize/test_normalize.py @@ -11,15 +11,13 @@ from dlt.common.typing import StrAny from dlt.common.file_storage import FileStorage from dlt.common.schema import TDataType -from dlt.common.storages.load_storage import LoadStorage -from dlt.common.storages.normalize_storage import NormalizeStorage -from dlt.common.storages import SchemaStorage +from dlt.common.storages import NormalizeStorage, LoadStorage from dlt.extract.extractor_storage import ExtractorStorageBase from dlt.normalize import Normalize, configuration as normalize_configuration, __version__ from tests.cases import JSON_TYPED_DICT, JSON_TYPED_DICT_TYPES -from tests.utils import TEST_STORAGE, assert_no_dict_key_starts_with, write_version, clean_storage, init_logger +from tests.utils import TEST_STORAGE_ROOT, assert_no_dict_key_starts_with, write_version, clean_test_storage, init_logger from tests.normalize.utils import json_case_path @@ -38,7 +36,7 @@ def rasa_normalize() -> Normalize: def init_normalize(default_schemas_path: str = None) -> Normalize: - clean_storage() + clean_test_storage() initial = {} if default_schemas_path: initial = {"IMPORT_SCHEMA_PATH": default_schemas_path, "EXTERNAL_SCHEMA_FORMAT": "json"} @@ -225,12 +223,12 @@ def test_normalize_typed_json(raw_normalize: Normalize) -> None: def extract_items(normalize_storage: NormalizeStorage, items: Sequence[StrAny], schema_name: str) -> None: - extractor = ExtractorStorageBase("1.0.0", True, FileStorage(os.path.join(TEST_STORAGE, "extractor"), makedirs=True), normalize_storage) + extractor = ExtractorStorageBase("1.0.0", True, FileStorage(os.path.join(TEST_STORAGE_ROOT, "extractor"), makedirs=True), normalize_storage) load_id = uniq_id() extractor.save_json(f"{load_id}.json", items) extractor.commit_events( schema_name, - extractor.storage._make_path(f"{load_id}.json"), + extractor.storage.make_full_path(f"{load_id}.json"), "items", len(items), load_id @@ -257,7 +255,7 @@ def normalize_cases(normalize: Normalize, cases: Sequence[str]) -> str: def copy_cases(normalize_storage: NormalizeStorage, cases: Sequence[str]) -> None: for case in cases: event_user_path = json_case_path(f"{case}.extracted") - shutil.copy(event_user_path, normalize_storage.storage._make_path(NormalizeStorage.EXTRACTED_FOLDER)) + shutil.copy(event_user_path, normalize_storage.storage.make_full_path(NormalizeStorage.EXTRACTED_FOLDER)) def expect_load_package(load_storage: LoadStorage, load_id: str, expected_tables: Sequence[str]) -> Dict[str, str]: @@ -266,7 +264,7 @@ def expect_load_package(load_storage: LoadStorage, load_id: str, expected_tables ofl: Dict[str, str] = {} for expected_table in expected_tables: # find all files for particular 
table, ignoring file id - file_mask = load_storage.build_job_file_name(expected_table, "*") + file_mask = load_storage.build_job_file_name(expected_table, "*", validate_components=False) # files are in normalized//new_jobs file_path = load_storage._get_job_file_path(load_id, "new_jobs", file_mask) candidates = [f for f in files if fnmatch(f, file_path)] From 059e77d4ed8f24ae3e5e854b9307ed8258f5be4d Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 20 Sep 2022 22:37:21 +0200 Subject: [PATCH 09/66] makes schema name validation part of name normalization, function to normalize all names in tables --- dlt/common/normalizers/json/relational.py | 1 + dlt/common/normalizers/names/snake_case.py | 8 +++ dlt/common/schema/__init__.py | 2 +- dlt/common/schema/detections.py | 1 + dlt/common/schema/exceptions.py | 2 +- dlt/common/schema/schema.py | 58 ++++++++++++---------- tests/common/schema/custom_normalizers.py | 4 ++ tests/common/schema/test_schema.py | 28 +++++------ 8 files changed, 61 insertions(+), 43 deletions(-) diff --git a/dlt/common/normalizers/json/relational.py b/dlt/common/normalizers/json/relational.py index 5b1b7c99f0..2c8a5864a5 100644 --- a/dlt/common/normalizers/json/relational.py +++ b/dlt/common/normalizers/json/relational.py @@ -67,6 +67,7 @@ def norm_row_dicts(dict_row: StrAny, __r_lvl: int, parent_name: Optional[str]) - # for lists and dicts we must check if type is possibly complex if isinstance(v, (dict, list)): if not _is_complex_type(schema, table, child_name, __r_lvl): + # TODO: if schema contains table {table}__{child_name} then convert v into single element list if isinstance(v, dict): # flatten the dict more norm_row_dicts(v, __r_lvl + 1, parent_name=child_name) diff --git a/dlt/common/normalizers/names/snake_case.py b/dlt/common/normalizers/names/snake_case.py index 27d1629966..639d91f089 100644 --- a/dlt/common/normalizers/names/snake_case.py +++ b/dlt/common/normalizers/names/snake_case.py @@ -1,5 +1,6 @@ import re from typing import Any, Sequence +from functools import lru_cache RE_UNDERSCORES = re.compile("_+") @@ -15,6 +16,7 @@ # fix a name so it's acceptable as database table name +@lru_cache(maxsize=None) def normalize_table_name(name: str) -> str: if not name: raise ValueError(name) @@ -34,11 +36,17 @@ def camel_to_snake(name: str) -> str: # fix a name so it's an acceptable name for a database column +@lru_cache(maxsize=None) def normalize_column_name(name: str) -> str: # replace consecutive underscores with single one to prevent name clashes with PATH_SEPARATOR return RE_UNDERSCORES.sub("_", normalize_table_name(name)) +# fix a name so it is acceptable as schema name +def normalize_schema_name(name: str) -> str: + return normalize_column_name(name) + + # build full db dataset (dataset) name out of (normalized) default dataset and schema name def normalize_make_dataset_name(default_dataset: str, default_schema_name: str, schema_name: str) -> str: if schema_name is None: diff --git a/dlt/common/schema/__init__.py b/dlt/common/schema/__init__.py index b3a95af283..ebbe4dfcaa 100644 --- a/dlt/common/schema/__init__.py +++ b/dlt/common/schema/__init__.py @@ -1,4 +1,4 @@ from dlt.common.schema.typing import TSchemaUpdate, TStoredSchema, TTableSchemaColumns, TDataType, THintType, TColumnSchema, TColumnSchemaBase # noqa: F401 from dlt.common.schema.typing import COLUMN_HINTS # noqa: F401 from dlt.common.schema.schema import Schema # noqa: F401 -from dlt.common.schema.utils import normalize_schema_name, add_missing_hints, verify_schema_hash # noqa: F401 +from 
dlt.common.schema.utils import add_missing_hints, verify_schema_hash # noqa: F401 diff --git a/dlt/common/schema/detections.py b/dlt/common/schema/detections.py index 697251de22..49acabf97b 100644 --- a/dlt/common/schema/detections.py +++ b/dlt/common/schema/detections.py @@ -27,6 +27,7 @@ def is_iso_timestamp(t: Type[Any], v: Any) -> Optional[TDataType]: return None # strict autodetection of iso timestamps try: + # TODO: use same functions as in coercions dt = pendulum.parse(v, strict=True, exact=True) if isinstance(dt, datetime.datetime): return "timestamp" diff --git a/dlt/common/schema/exceptions.py b/dlt/common/schema/exceptions.py index 57e43b81cd..cf335b5763 100644 --- a/dlt/common/schema/exceptions.py +++ b/dlt/common/schema/exceptions.py @@ -11,7 +11,7 @@ class SchemaException(DltException): class InvalidSchemaName(SchemaException): def __init__(self, name: str, normalized_name: str) -> None: self.name = name - super().__init__(f"{name} is invalid schema name. Only lowercase letters are allowed. Try {normalized_name} instead") + super().__init__(f"{name} is invalid schema name. Try {normalized_name} instead") class CannotCoerceColumnException(SchemaException): diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index db7d14f2a6..e69a717b65 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -12,7 +12,7 @@ THintType, TWriteDisposition) from dlt.common.schema import utils from dlt.common.schema.exceptions import (CannotCoerceColumnException, CannotCoerceNullException, InvalidSchemaName, - ParentTableNotFoundException, SchemaCorruptedException, TablePropertiesConflictException) + ParentTableNotFoundException, SchemaCorruptedException) from dlt.common.validation import validate_dict @@ -24,9 +24,6 @@ class Schema: ENGINE_VERSION = 4 def __init__(self, name: str, normalizers: TNormalizersConfig = None) -> None: - # verify schema name - if name != utils.normalize_schema_name(name): - raise InvalidSchemaName(name, utils.normalize_schema_name(name)) self._schema_tables: TSchemaTables = {} self._schema_name: str = name self._stored_version = 1 # version at load/creation time @@ -62,6 +59,8 @@ def __init__(self, name: str, normalizers: TNormalizersConfig = None) -> None: self._add_standard_hints() # configure normalizers, including custom config if present self._configure_normalizers() + # verify schema name after configuring normalizers + self._verify_schema_name(name) # compile all known regexes self._compile_regexes() # set initial version hash @@ -189,27 +188,14 @@ def update_schema(self, partial_table: TPartialTableSchema) -> None: # add the whole new table to SchemaTables self._schema_tables[table_name] = partial_table else: - # check if table properties can be merged - if table.get("parent") != partial_table.get("parent"): - raise TablePropertiesConflictException(table_name, "parent", table.get("parent"), partial_table.get("parent")) - # check if partial table has write disposition set - partial_w_d = partial_table.get("write_disposition") - if partial_w_d: - # get write disposition recursively for existing table - existing_w_d = self.get_write_disposition(table_name) - if existing_w_d != partial_w_d: - raise TablePropertiesConflictException(table_name, "write_disposition", existing_w_d, partial_w_d) - # add several columns to existing table - table_columns = table["columns"] - for column in partial_table["columns"].values(): - column_name = column["name"] - if column_name in table_columns: - # we do not support changing existing columns - 
if not utils.compare_columns(table_columns[column_name], column): - # attempt to update to incompatible columns - raise CannotCoerceColumnException(table_name, column_name, column["data_type"], table_columns[column_name]["data_type"], None) - else: - table_columns[column_name] = column + # partial_w_d = partial_table.get("write_disposition") + # if table.get("parent") and not table.get("write_disposition") and partial_w_d: + # # get write disposition recursively for existing table and check if those fit + # existing_w_d = self.get_write_disposition(table_name) + # if existing_w_d != partial_table.get("write_disposition"): + # raise TablePropertiesConflictException(table_name, "write_disposition", existing_w_d, partial_w_d) + # merge tables performing additional checks + utils.merge_tables(table, partial_table) def bump_version(self) -> Tuple[int, str]: """Computes schema hash in order to check if schema content was modified. In such case the schema ``stored_version`` and ``stored_version_hash`` are updated. @@ -256,7 +242,21 @@ def merge_hints(self, new_hints: Mapping[THintType, Sequence[TSimpleRegex]]) -> default_hints[h] = l # type: ignore self._compile_regexes() - def get_schema_update_for(self, table_name: str, t: TTableSchemaColumns) -> List[TColumnSchema]: + def normalize_table_identifiers(self, table: TTableSchema) -> TTableSchema: + # normalize all identifiers in table according to name normalizer of the schema + table["name"] = self.normalize_table_name(table["name"]) + parent = table.get("parent") + if parent: + table["parent"] = self.normalize_table_name(parent) + columns = table.get("columns") + if columns: + for c in columns.values(): + c["name"] = self.normalize_column_name(c["name"]) + # re-index columns as the name changed + table["columns"] = {c["name"]:c for c in columns.values()} + return table + + def get_new_columns(self, table_name: str, t: TTableSchemaColumns) -> List[TColumnSchema]: # gets new columns to be added to "t" to bring up to date with stored schema diff_c: List[TColumnSchema] = [] s_t = self.get_table_columns(table_name) @@ -431,7 +431,7 @@ def _configure_normalizers(self) -> None: # name normalization functions self.normalize_table_name = naming_module.normalize_table_name self.normalize_column_name = naming_module.normalize_column_name - self.normalize_schema_name = utils.normalize_schema_name + self.normalize_schema_name = naming_module.normalize_schema_name self.normalize_make_dataset_name = naming_module.normalize_make_dataset_name self.normalize_make_path = naming_module.normalize_make_path self.normalize_break_path = naming_module.normalize_break_path @@ -439,6 +439,10 @@ def _configure_normalizers(self) -> None: self.normalize_data_item = json_module.normalize_data_item json_module.extend_schema(self) + def _verify_schema_name(self, name: str) -> None: + if name != self.normalize_schema_name(name): + raise InvalidSchemaName(name, self.normalize_schema_name(name)) + def _compile_regexes(self) -> None: if self._settings: for pattern, dt in self._settings.get("preferred_types", {}).items(): diff --git a/tests/common/schema/custom_normalizers.py b/tests/common/schema/custom_normalizers.py index 799d7a637e..5709dd8db2 100644 --- a/tests/common/schema/custom_normalizers.py +++ b/tests/common/schema/custom_normalizers.py @@ -12,6 +12,10 @@ def normalize_column_name(name: str) -> str: return "column_" + name.lower() +def normalize_schema_name(name: str) -> str: + return name.lower() + + def extend_schema(schema: Schema) -> None: json_config = 
schema._normalizers_config["json"]["config"] d_h = schema._settings.setdefault("default_hints", {}) diff --git a/tests/common/schema/test_schema.py b/tests/common/schema/test_schema.py index 07297cf270..d664b14cf4 100644 --- a/tests/common/schema/test_schema.py +++ b/tests/common/schema/test_schema.py @@ -11,7 +11,7 @@ from dlt.common.schema.exceptions import InvalidSchemaName, ParentTableNotFoundException, SchemaEngineNoUpgradePathException from dlt.common.storages import SchemaStorage -from tests.utils import autouse_root_storage +from tests.utils import autouse_test_storage from tests.common.utils import load_json_case, load_yml_case SCHEMA_NAME = "event" @@ -54,9 +54,9 @@ def cn_schema() -> Schema: def test_normalize_schema_name(schema: Schema) -> None: - assert schema.normalize_schema_name("BAN_ANA") == "banana" - assert schema.normalize_schema_name("event-.!:value") == "eventvalue" - assert schema.normalize_schema_name("123event-.!:value") == "s123eventvalue" + assert schema.normalize_schema_name("BAN_ANA") == "ban_ana" + assert schema.normalize_schema_name("event-.!:value") == "event_value" + assert schema.normalize_schema_name("123event-.!:value") == "_123event_value" with pytest.raises(ValueError): assert schema.normalize_schema_name("") with pytest.raises(ValueError): @@ -117,8 +117,8 @@ def test_column_name_validator(schema: Schema) -> None: def test_invalid_schema_name() -> None: with pytest.raises(InvalidSchemaName) as exc: - Schema("a_b") - assert exc.value.name == "a_b" + Schema("a!b") + assert exc.value.name == "a!b" @pytest.mark.parametrize("columns,hint,value", [ @@ -143,7 +143,7 @@ def test_save_store_schema(schema: Schema, schema_storage: SchemaStorage) -> Non assert not schema_storage.storage.has_file(EXPECTED_FILE_NAME) saved_file_name = schema_storage.save_schema(schema) # return absolute path - assert saved_file_name == schema_storage.storage._make_path(EXPECTED_FILE_NAME) + assert saved_file_name == schema_storage.storage.make_full_path(EXPECTED_FILE_NAME) assert schema_storage.storage.has_file(EXPECTED_FILE_NAME) schema_copy = schema_storage.load_schema("event") assert schema.name == schema_copy.name @@ -369,16 +369,16 @@ def test_compare_columns() -> None: ]) # columns identical with self for c in table["columns"].values(): - assert utils.compare_columns(c, c) is True - assert utils.compare_columns(table["columns"]["col3"], table["columns"]["col4"]) is True + assert utils.compare_column(c, c) is True + assert utils.compare_column(table["columns"]["col3"], table["columns"]["col4"]) is True # data type may not differ - assert utils.compare_columns(table["columns"]["col1"], table["columns"]["col3"]) is False + assert utils.compare_column(table["columns"]["col1"], table["columns"]["col3"]) is False # nullability may not differ - assert utils.compare_columns(table["columns"]["col1"], table["columns"]["col2"]) is False + assert utils.compare_column(table["columns"]["col1"], table["columns"]["col2"]) is False # any of the hints may differ for hint in COLUMN_HINTS: table["columns"]["col3"][hint] = True - assert utils.compare_columns(table["columns"]["col3"], table["columns"]["col4"]) is True + assert utils.compare_column(table["columns"]["col3"], table["columns"]["col4"]) is True def assert_new_schema_values_custom_normalizers(schema: Schema) -> None: @@ -390,7 +390,7 @@ def assert_new_schema_values_custom_normalizers(schema: Schema) -> None: # call normalizers assert schema.normalize_column_name("a") == "column_a" assert schema.normalize_table_name("a__b") == "A__b" - 
assert schema.normalize_schema_name("1A_b") == "s1ab" + assert schema.normalize_schema_name("1A_b") == "1a_b" # assumes elements are normalized assert schema.normalize_make_path("A", "B", "!C") == "A__B__!C" assert schema.normalize_break_path("A__B__!C") == ["A", "B", "!C"] @@ -413,7 +413,7 @@ def assert_new_schema_values(schema: Schema) -> None: # call normalizers assert schema.normalize_column_name("A") == "a" assert schema.normalize_table_name("A__B") == "a__b" - assert schema.normalize_schema_name("1A_b") == "s1ab" + assert schema.normalize_schema_name("1A_b") == "_1_a_b" # assumes elements are normalized assert schema.normalize_make_path("A", "B", "!C") == "A__B__!C" assert schema.normalize_break_path("A__B__!C") == ["A", "B", "!C"] From e9b70dddb177bde608bf175b37f6e224a82bdccf Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 20 Sep 2022 22:37:34 +0200 Subject: [PATCH 10/66] several name cleanips --- Makefile | 2 +- dlt/dbt_runner/runner.py | 2 +- tests/common/utils.py | 11 ++++++++ tests/dbt_runner/test_runner_redshift.py | 4 +-- tests/dbt_runner/test_utils.py | 36 ++++++++++++------------ tests/dbt_runner/utils.py | 4 +-- tests/tools/create_storages.py | 8 ++---- tests/utils.py | 22 +++++++-------- 8 files changed, 48 insertions(+), 41 deletions(-) diff --git a/Makefile b/Makefile index ef0f74c959..fb1ca191d1 100644 --- a/Makefile +++ b/Makefile @@ -38,7 +38,7 @@ dev: has-poetry lint: ./check-package.sh - poetry run mypy --config-file mypy.ini dlt examples + poetry run mypy --config-file mypy.ini dlt examples experiments/pipeline poetry run flake8 --max-line-length=200 examples dlt poetry run flake8 --max-line-length=200 tests # dlt/pipeline dlt/common/schema dlt/common/normalizers diff --git a/dlt/dbt_runner/runner.py b/dlt/dbt_runner/runner.py index 5ad572fe54..ac186ab59c 100644 --- a/dlt/dbt_runner/runner.py +++ b/dlt/dbt_runner/runner.py @@ -45,7 +45,7 @@ def create_folders() -> Tuple[FileStorage, StrAny, Sequence[str], str, str]: global_args = initialize_dbt_logging(CONFIG.LOG_LEVEL, is_json_logging(CONFIG.LOG_FORMAT)) # generate path for the dbt package repo - repo_path = storage._make_path(CLONED_PACKAGE_NAME) + repo_path = storage.make_full_path(CLONED_PACKAGE_NAME) # generate profile name profile_name: str = None diff --git a/tests/common/utils.py b/tests/common/utils.py index e85b4649f3..ef423214ae 100644 --- a/tests/common/utils.py +++ b/tests/common/utils.py @@ -2,6 +2,9 @@ from typing import Mapping, cast from dlt.common import json +from dlt.common.typing import StrAny +from dlt.common.schema import utils +from dlt.common.schema.typing import TTableSchemaColumns def load_json_case(name: str) -> Mapping: @@ -20,3 +23,11 @@ def json_case_path(name: str) -> str: def yml_case_path(name: str) -> str: return f"./tests/common/cases/{name}.yml" + + +def row_to_column_schemas(row: StrAny) -> TTableSchemaColumns: + return {k: utils.add_missing_hints({ + "name": k, + "data_type": "text", + "nullable": False + }) for k in row.keys()} diff --git a/tests/dbt_runner/test_runner_redshift.py b/tests/dbt_runner/test_runner_redshift.py index a7cc1ac8b6..d2fc08743e 100644 --- a/tests/dbt_runner/test_runner_redshift.py +++ b/tests/dbt_runner/test_runner_redshift.py @@ -16,7 +16,7 @@ from dlt.dbt_runner import runner from dlt.load.redshift.client import RedshiftSqlClient -from tests.utils import add_config_to_env, clean_storage, init_logger, preserve_environ +from tests.utils import add_config_to_env, clean_test_storage, init_logger, preserve_environ from tests.dbt_runner.utils 
import modify_and_commit_file, load_secret, setup_runner DEST_SCHEMA_PREFIX = "test_" + uniq_id() @@ -178,7 +178,7 @@ def test_dbt_incremental_schema_out_of_sync_error() -> None: def get_runner() -> FileStorage: - clean_storage() + clean_test_storage() runner.storage, runner.dbt_package_vars, runner.global_args, runner.repo_path, runner.profile_name = runner.create_folders() runner.model_elapsed_gauge, runner.model_exec_info = runner.create_gauges(CollectorRegistry(auto_describe=True)) return runner.storage diff --git a/tests/dbt_runner/test_utils.py b/tests/dbt_runner/test_utils.py index 1e849246d6..2ea3191930 100644 --- a/tests/dbt_runner/test_utils.py +++ b/tests/dbt_runner/test_utils.py @@ -7,7 +7,7 @@ from dlt.dbt_runner.utils import DBTProcessingError, clone_repo, ensure_remote_head, git_custom_key_command, initialize_dbt_logging, run_dbt_command -from tests.utils import root_storage +from tests.utils import test_storage from tests.dbt_runner.utils import load_secret, modify_and_commit_file, restore_secret_storage_path @@ -32,60 +32,60 @@ def test_no_ssh_key_context() -> None: assert git_command == 'ssh -o "StrictHostKeyChecking accept-new"' -def test_clone(root_storage: FileStorage) -> None: - repo_path = root_storage._make_path("awesome_repo") +def test_clone(test_storage: FileStorage) -> None: + repo_path = test_storage.make_full_path("awesome_repo") # clone a small public repo clone_repo(AWESOME_REPO, repo_path, with_git_command=None) - assert root_storage.has_folder("awesome_repo") + assert test_storage.has_folder("awesome_repo") # make sure directory clean ensure_remote_head(repo_path, with_git_command=None) -def test_clone_with_commit_id(root_storage: FileStorage) -> None: - repo_path = root_storage._make_path("awesome_repo") +def test_clone_with_commit_id(test_storage: FileStorage) -> None: + repo_path = test_storage.make_full_path("awesome_repo") # clone a small public repo clone_repo(AWESOME_REPO, repo_path, with_git_command=None, branch="7f88000be2d4f265c83465fec4b0b3613af347dd") - assert root_storage.has_folder("awesome_repo") + assert test_storage.has_folder("awesome_repo") ensure_remote_head(repo_path, with_git_command=None) -def test_clone_with_wrong_branch(root_storage: FileStorage) -> None: - repo_path = root_storage._make_path("awesome_repo") +def test_clone_with_wrong_branch(test_storage: FileStorage) -> None: + repo_path = test_storage.make_full_path("awesome_repo") # clone a small public repo with pytest.raises(GitCommandError): clone_repo(AWESOME_REPO, repo_path, with_git_command=None, branch="wrong_branch") -def test_clone_with_deploy_key_access_denied(root_storage: FileStorage) -> None: +def test_clone_with_deploy_key_access_denied(test_storage: FileStorage) -> None: secret = load_secret("deploy_key") - repo_path = root_storage._make_path("private_repo") + repo_path = test_storage.make_full_path("private_repo") with git_custom_key_command(secret) as git_command: with pytest.raises(GitCommandError): clone_repo(PRIVATE_REPO, repo_path, with_git_command=git_command) -def test_clone_with_deploy_key(root_storage: FileStorage) -> None: +def test_clone_with_deploy_key(test_storage: FileStorage) -> None: secret = load_secret("deploy_key") - repo_path = root_storage._make_path("private_repo_access") + repo_path = test_storage.make_full_path("private_repo_access") with git_custom_key_command(secret) as git_command: clone_repo(PRIVATE_REPO_WITH_ACCESS, repo_path, with_git_command=git_command) ensure_remote_head(repo_path, with_git_command=git_command) -def 
test_repo_status_update(root_storage: FileStorage) -> None: +def test_repo_status_update(test_storage: FileStorage) -> None: secret = load_secret("deploy_key") - repo_path = root_storage._make_path("private_repo_access") + repo_path = test_storage.make_full_path("private_repo_access") with git_custom_key_command(secret) as git_command: clone_repo(PRIVATE_REPO_WITH_ACCESS, repo_path, with_git_command=git_command) # modify README.md readme_path = modify_and_commit_file(repo_path, "README.md") - assert root_storage.has_file(readme_path) + assert test_storage.has_file(readme_path) with pytest.raises(RepositoryDirtyError): ensure_remote_head(repo_path, with_git_command=git_command) -def test_dbt_commands(root_storage: FileStorage) -> None: - repo_path = root_storage._make_path("jaffle_shop") +def test_dbt_commands(test_storage: FileStorage) -> None: + repo_path = test_storage.make_full_path("jaffle_shop") # clone jaffle shop for dbt 1.0.0 clone_repo(JAFFLE_SHOP_REPO, repo_path, with_git_command=None, branch="core-v1.0.0") # copy profile diff --git a/tests/dbt_runner/utils.py b/tests/dbt_runner/utils.py index 2484fbdb9b..304984b93f 100644 --- a/tests/dbt_runner/utils.py +++ b/tests/dbt_runner/utils.py @@ -8,7 +8,7 @@ from dlt.dbt_runner.configuration import gen_configuration_variant from dlt.dbt_runner import runner -from tests.utils import clean_storage +from tests.utils import clean_test_storage SECRET_STORAGE_PATH = environ.SECRET_STORAGE_PATH @@ -45,7 +45,7 @@ def modify_and_commit_file(repo_path: str, file_name: str, content: str = "NEW R def setup_runner(dest_schema_prefix: str, override_values: StrAny = None) -> None: - clean_storage() + clean_test_storage() C = gen_configuration_variant(initial_values=override_values) # set unique dest schema prefix by default C.DEST_SCHEMA_PREFIX = dest_schema_prefix diff --git a/tests/tools/create_storages.py b/tests/tools/create_storages.py index 7606f433be..f2ae8d2260 100644 --- a/tests/tools/create_storages.py +++ b/tests/tools/create_storages.py @@ -1,9 +1,5 @@ -from dlt.common.storages.normalize_storage import NormalizeStorage -from dlt.common.configuration import NormalizeVolumeConfiguration -from dlt.common.storages.load_storage import LoadStorage -from dlt.common.configuration import LoadVolumeConfiguration -from dlt.common.storages.schema_storage import SchemaStorage -from dlt.common.configuration import SchemaVolumeConfiguration +from dlt.common.storages import NormalizeStorage, LoadStorage, SchemaStorage +from dlt.common.configuration import NormalizeVolumeConfiguration, LoadVolumeConfiguration, SchemaVolumeConfiguration NormalizeStorage(True, NormalizeVolumeConfiguration) diff --git a/tests/utils.py b/tests/utils.py index be4d1b544c..aa697fe481 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,7 +15,7 @@ from dlt.common.typing import StrAny -TEST_STORAGE = "_storage" +TEST_STORAGE_ROOT = "_storage" class MockHttpResponse(): @@ -31,20 +31,20 @@ def write_version(storage: FileStorage, version: str) -> None: storage.save(VersionedStorage.VERSION_FILE, str(version)) -def delete_storage() -> None: - storage = FileStorage(TEST_STORAGE) +def delete_test_storage() -> None: + storage = FileStorage(TEST_STORAGE_ROOT) if storage.has_folder(""): storage.delete_folder("", recursively=True) @pytest.fixture() -def root_storage() -> FileStorage: - return clean_storage() +def test_storage() -> FileStorage: + return clean_test_storage() @pytest.fixture(autouse=True) -def autouse_root_storage() -> FileStorage: - return clean_storage() +def 
autouse_test_storage() -> FileStorage: + return clean_test_storage() @pytest.fixture(scope="module", autouse=True) @@ -62,16 +62,16 @@ def init_logger(C: Type[RunConfiguration] = None) -> None: init_logging_from_config(C) -def clean_storage(init_normalize: bool = False, init_loader: bool = False) -> FileStorage: - storage = FileStorage(TEST_STORAGE, "t", makedirs=True) +def clean_test_storage(init_normalize: bool = False, init_loader: bool = False) -> FileStorage: + storage = FileStorage(TEST_STORAGE_ROOT, "t", makedirs=True) storage.delete_folder("", recursively=True) storage.create_folder(".") if init_normalize: - from dlt.common.storages.normalize_storage import NormalizeStorage + from dlt.common.storages import NormalizeStorage from dlt.common.configuration import NormalizeVolumeConfiguration NormalizeStorage(True, NormalizeVolumeConfiguration) if init_loader: - from dlt.common.storages.load_storage import LoadStorage + from dlt.common.storages import LoadStorage from dlt.common.configuration import LoadVolumeConfiguration LoadStorage(True, LoadVolumeConfiguration, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) return storage From fcaaef662df8411bc81c966993fc13c2e5f26559 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 24 Sep 2022 17:45:48 +0200 Subject: [PATCH 11/66] adds close data writer operation and written items count --- dlt/common/data_writers/buffered.py | 40 +++++++++++++++++++++------ dlt/common/data_writers/exceptions.py | 6 ++++ dlt/common/data_writers/writers.py | 8 ++++-- 3 files changed, 43 insertions(+), 11 deletions(-) diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py index 26ca1c28f9..f27db9a65a 100644 --- a/dlt/common/data_writers/buffered.py +++ b/dlt/common/data_writers/buffered.py @@ -4,32 +4,36 @@ from dlt.common.typing import TDataItem from dlt.common.sources import TDirectDataItem from dlt.common.data_writers import TLoaderFileFormat -from dlt.common.data_writers.exceptions import InvalidFileNameTemplateException +from dlt.common.data_writers.exceptions import BufferedDataWriterClosed, InvalidFileNameTemplateException from dlt.common.data_writers.writers import DataWriter from dlt.common.schema.typing import TTableSchemaColumns class BufferedDataWriter: - def __init__(self, file_format: TLoaderFileFormat, file_name_template: str, buffer_max_items: int = 5000, file_max_bytes: int = None): + def __init__(self, file_format: TLoaderFileFormat, file_name_template: str, buffer_max_items: int = 5000, file_max_items: int = None, file_max_bytes: int = None): self.file_format = file_format self._file_format_spec = DataWriter.data_format_from_file_format(self.file_format) # validate if template has correct placeholders self.file_name_template = file_name_template self.all_files: List[str] = [] - self.buffer_max_items = buffer_max_items + # buffered items must be less than max items in file + self.buffer_max_items = min(buffer_max_items, file_max_items or buffer_max_items) self.file_max_bytes = file_max_bytes + self.file_max_items = file_max_items self._current_columns: TTableSchemaColumns = None self._file_name: str = None self._buffered_items: List[TDataItem] = [] self._writer: DataWriter = None self._file: IO[Any] = None + self._closed = False try: self._rotate_file() except TypeError: raise InvalidFileNameTemplateException(file_name_template) def write_data_item(self, item: TDirectDataItem, columns: TTableSchemaColumns) -> None: + self._ensure_open() # rotate file if columns changed and writer does not allow for that # as 
the only allowed change is to add new column (no updates/deletes), we detect the change by comparing lengths if self._writer and not self._writer.data_format().supports_schema_changes and len(columns) != len(self._current_columns): @@ -42,15 +46,29 @@ def write_data_item(self, item: TDirectDataItem, columns: TTableSchemaColumns) - else: self._buffered_items.append(item) # flush if max buffer exceeded - if len(self._buffered_items) > self.buffer_max_items: + if len(self._buffered_items) >= self.buffer_max_items: self._flush_items() # rotate the file if max_bytes exceeded - if self.file_max_bytes and self._file and self._file.tell() > self.file_max_bytes: - self._rotate_file() + if self._file: + # rotate on max file size + if self.file_max_bytes and self._file.tell() >= self.file_max_bytes: + self._rotate_file() + # rotate on max items + if self.file_max_items and self._writer.items_count >= self.file_max_items: + self._rotate_file() + + def close_writer(self) -> None: + self._ensure_open() + self._flush_and_close_file() + self._closed = True + + @property + def closed(self) -> bool: + return self._closed def _rotate_file(self) -> None: - self.close_writer() - self._file_name = self.file_name_template % uniq_id() + "." + self._file_format_spec.file_extension + self._flush_and_close_file() + self._file_name = self.file_name_template % uniq_id(5) + "." + self._file_format_spec.file_extension def _flush_items(self) -> None: if len(self._buffered_items) > 0: @@ -67,7 +85,7 @@ def _flush_items(self) -> None: self._writer.write_data(self._buffered_items) self._buffered_items.clear() - def close_writer(self) -> None: + def _flush_and_close_file(self) -> None: # if any buffered items exist, flush them self._flush_items() # if writer exists then close it @@ -79,3 +97,7 @@ def close_writer(self) -> None: self.all_files.append(self._file_name) self._writer = None self._file = None + + def _ensure_open(self) -> None: + if self._closed: + raise BufferedDataWriterClosed(self._file_name) \ No newline at end of file diff --git a/dlt/common/data_writers/exceptions.py b/dlt/common/data_writers/exceptions.py index 4f249eb142..ffba6a49ba 100644 --- a/dlt/common/data_writers/exceptions.py +++ b/dlt/common/data_writers/exceptions.py @@ -9,3 +9,9 @@ class InvalidFileNameTemplateException(DataWriterException, ValueError): def __init__(self, file_name_template: str): self.file_name_template = file_name_template super().__init__(f"Wrong file name template {file_name_template}. 
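The rotation logic above closes the current file once either file_max_bytes or file_max_items is exceeded, and clamps buffer_max_items so a single flush can never overshoot the per-file item limit. A usage sketch against the constructor shown in this hunk (path, format and limits are illustrative only):

    from dlt.common.data_writers.buffered import BufferedDataWriter

    # the template must contain exactly one %s placeholder for the unique file id
    writer = BufferedDataWriter(
        "jsonl",
        "/tmp/items.%s",                      # hypothetical location
        buffer_max_items=1_000,
        file_max_items=10_000,                # rotate after ~10k items ...
        file_max_bytes=50 * 1024 * 1024,      # ... or ~50 MB, whichever comes first
    )
    for i in range(25_000):
        writer.write_data_item({"id": i}, columns={})
    writer.close_writer()
    print(writer.all_files)                   # several rotated .jsonl files
    # any further write_data_item() call now raises BufferedDataWriterClosed
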
File name template must contain exactly one %s formatter") + + +class BufferedDataWriterClosed(DataWriterException): + def __init__(self, file_name: str): + self.file_name = file_name + super().__init__(f"Writer with recent file name {file_name} is already closed") diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index cc1a8fe212..d2d788b64e 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -24,14 +24,14 @@ class TFileFormatSpec: class DataWriter(abc.ABC): def __init__(self, f: IO[Any]) -> None: self._f = f + self.items_count = 0 @abc.abstractmethod def write_header(self, columns_schema: TTableSchemaColumns) -> None: pass - @abc.abstractmethod def write_data(self, rows: Sequence[Any]) -> None: - pass + self.items_count += len(rows) @abc.abstractmethod def write_footer(self) -> None: @@ -74,6 +74,7 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None: pass def write_data(self, rows: Sequence[Any]) -> None: + super().write_data(rows) # use jsonl to write load files https://jsonlines.org/ with jsonlines.Writer(self._f, dumps=json.dumps) as w: w.write_all(rows) @@ -89,6 +90,8 @@ def data_format(cls) -> TFileFormatSpec: class JsonlPUAEncodeWriter(JsonlWriter): def write_data(self, rows: Sequence[Any]) -> None: + # skip JsonlWriter when calling super + super(JsonlWriter, self).write_data(rows) # encode types with PUA characters with jsonlines.Writer(self._f, dumps=json_typed_dumps) as w: w.write_all(rows) @@ -116,6 +119,7 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None: self._f.write(")\nVALUES\n") def write_data(self, rows: Sequence[Any]) -> None: + super().write_data(rows) def stringify(v: Any) -> str: if isinstance(v, bytes): From 40c5acce9b81df8f1c5fd14e541f003b97e07721 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 24 Sep 2022 17:47:50 +0200 Subject: [PATCH 12/66] adds explicit table name to json normalizer func signature --- dlt/common/normalizers/json/__init__.py | 11 +++++-- dlt/common/normalizers/json/relational.py | 33 +++++++++---------- .../normalizers/test_json_relational.py | 21 +++++------- tests/common/schema/custom_normalizers.py | 4 +-- 4 files changed, 34 insertions(+), 35 deletions(-) diff --git a/dlt/common/normalizers/json/__init__.py b/dlt/common/normalizers/json/__init__.py index 0b26371078..81564972c9 100644 --- a/dlt/common/normalizers/json/__init__.py +++ b/dlt/common/normalizers/json/__init__.py @@ -1,7 +1,7 @@ -from typing import Iterator, Tuple, Callable, TYPE_CHECKING +from typing import Any, Iterator, Tuple, Callable, TYPE_CHECKING -from dlt.common.typing import TDataItem, StrAny +from dlt.common.typing import DictStrAny, TDataItem, StrAny if TYPE_CHECKING: from dlt.common.schema import Schema @@ -11,4 +11,9 @@ TNormalizedRowIterator = Iterator[Tuple[Tuple[str, str], StrAny]] # normalization function signature -TNormalizeJSONFunc = Callable[["Schema", TDataItem, str], TNormalizedRowIterator] +TNormalizeJSONFunc = Callable[["Schema", TDataItem, str, str], TNormalizedRowIterator] + + +def wrap_in_dict(item: Any) -> DictStrAny: + """Wraps `item` that is not a dictionary into dictionary that can be json normalized""" + return {"value": item} diff --git a/dlt/common/normalizers/json/relational.py b/dlt/common/normalizers/json/relational.py index 2c8a5864a5..3fb67a9364 100644 --- a/dlt/common/normalizers/json/relational.py +++ b/dlt/common/normalizers/json/relational.py @@ -5,8 +5,8 @@ from dlt.common.schema.typing import TColumnSchema, 
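items_count is maintained in the abstract base, so every concrete write_data has to call up the inheritance chain; JsonlPUAEncodeWriter calls super(JsonlWriter, self).write_data(rows) so the items are still counted without also being serialized by JsonlWriter. The same pattern in isolation (class names here are made up, not the dlt writers):

    import abc
    import io
    from typing import Any, IO, Sequence

    class CountingWriter(abc.ABC):
        def __init__(self, f: IO[Any]) -> None:
            self._f = f
            self.items_count = 0

        def write_data(self, rows: Sequence[Any]) -> None:
            # the base class only counts, subclasses do the serialization
            self.items_count += len(rows)

    class PlainWriter(CountingWriter):
        def write_data(self, rows: Sequence[Any]) -> None:
            super().write_data(rows)
            self._f.write("\n".join(map(str, rows)))

    class VariantWriter(PlainWriter):
        def write_data(self, rows: Sequence[Any]) -> None:
            # skip PlainWriter.write_data but keep the counting from CountingWriter
            super(PlainWriter, self).write_data(rows)
            self._f.write("|".join(map(str, rows)))

    w = VariantWriter(io.StringIO())
    w.write_data([1, 2, 3])
    print(w.items_count)  # 3 - counted once even though PlainWriter.write_data was skipped
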
TColumnName, TSimpleRegex from dlt.common.schema.utils import column_name_validator from dlt.common.utils import uniq_id, digest128 -from dlt.common.normalizers.json import TNormalizedRowIterator -from dlt.common.sources import DLT_METADATA_FIELD, TEventDLTMeta, get_table_name +from dlt.common.normalizers.json import TNormalizedRowIterator, wrap_in_dict +from dlt.common.sources import TEventDLTMeta from dlt.common.validation import validate_dict @@ -40,7 +40,7 @@ class JSONNormalizerConfig(TypedDict, total=True): # for those paths the complex nested objects should be left in place def _is_complex_type(schema: Schema, table_name: str, field_name: str, _r_lvl: int) -> bool: # turn everything at the recursion level into complex type - max_nesting = schema._normalizers_config["json"].get("config", {}).get("max_nesting", 1000) + max_nesting = (schema._normalizers_config["json"].get("config") or {}).get("max_nesting", 1000) assert _r_lvl <= max_nesting if _r_lvl == max_nesting: return True @@ -48,7 +48,7 @@ def _is_complex_type(schema: Schema, table_name: str, field_name: str, _r_lvl: i column: TColumnSchema = None table = schema._schema_tables.get(table_name) if table: - column = table["columns"].get(field_name, None) + column = table["columns"].get(field_name) if column is None: data_type = schema.get_preferred_type(field_name) else: @@ -103,14 +103,14 @@ def _get_content_hash(schema: Schema, table: str, row: StrAny) -> str: def _get_propagated_values(schema: Schema, table: str, row: TEventRow, is_top_level: bool) -> StrAny: - config: JSONNormalizerConfigPropagation = schema._normalizers_config["json"].get("config", {}).get("propagation", None) + config: JSONNormalizerConfigPropagation = (schema._normalizers_config["json"].get("config") or {}).get("propagation", None) extend: DictStrAny = {} if config: # mapping(k:v): propagate property with name "k" as property with name "v" in child table mappings: DictStrStr = {} if is_top_level: - mappings.update(config.get("root", {})) - if table in config.get("tables", {}): + mappings.update(config.get("root") or {}) + if table in (config.get("tables") or {}): mappings.update(config["tables"][table]) # look for keys and create propagation as values for prop_from, prop_as in mappings.items(): @@ -147,7 +147,9 @@ def _normalize_list( else: # list of simple types child_row_hash = _get_child_row_hash(parent_row_id, table, idx) - e = _add_linking({"value": v, "_dlt_id": child_row_hash}, extend, parent_row_id, idx) + wrap_v = wrap_in_dict(v) + wrap_v["_dlt_id"] = child_row_hash + e = _add_linking(wrap_v, extend, parent_row_id, idx) _extend_row(extend, e) yield (table, parent_table), e @@ -199,11 +201,11 @@ def _normalize_row( def extend_schema(schema: Schema) -> None: # validate config - config = schema._normalizers_config["json"].get("config", {}) + config = schema._normalizers_config["json"].get("config") or {} validate_dict(JSONNormalizerConfig, config, "./normalizers/json/config", validator_f=column_name_validator(schema.normalize_column_name)) # quick check to see if hints are applied - default_hints = schema.settings.get("default_hints", {}) + default_hints = schema.settings.get("default_hints") or {} if "not_null" in default_hints and "^_dlt_id$" in default_hints["not_null"]: return # add hints @@ -219,14 +221,9 @@ def extend_schema(schema: Schema) -> None: ) -def normalize_data_item(schema: Schema, source_event: TDataItem, load_id: str) -> TNormalizedRowIterator: +def normalize_data_item(schema: Schema, item: TDataItem, load_id: str, table_name: str) 
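Several lookups in this file switch from d.get("config", {}) to (d.get("config") or {}). The difference matters when the key is present but holds an explicit None (or another falsy value):

    cfg = {"config": None}
    print(cfg.get("config", {}))      # None - the default only applies when the key is missing
    print(cfg.get("config") or {})    # {}   - also falls back for falsy values such as None
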
-> TNormalizedRowIterator: # we will extend event with all the fields necessary to load it as root row - event = cast(TEventRowRoot, source_event) + event = cast(TEventRowRoot, item) # identify load id if loaded data must be processed after loading incrementally event["_dlt_load_id"] = load_id - # find table name - table_name = schema.normalize_table_name(get_table_name(event) or schema.name) - # drop dlt metadata before normalizing - event.pop(DLT_METADATA_FIELD, None) # type: ignore - # use event type or schema name as table name, request _dlt_root_id propagation - yield from _normalize_row(schema, cast(TEventRowChild, event), {}, table_name) + yield from _normalize_row(schema, cast(TEventRowChild, event), {}, schema.normalize_table_name(table_name)) diff --git a/tests/common/normalizers/test_json_relational.py b/tests/common/normalizers/test_json_relational.py index b432768d7d..47a383b778 100644 --- a/tests/common/normalizers/test_json_relational.py +++ b/tests/common/normalizers/test_json_relational.py @@ -4,7 +4,6 @@ from dlt.common.utils import digest128, uniq_id from dlt.common.schema import Schema from dlt.common.schema.utils import new_table -from dlt.common.sources import DLT_METADATA_FIELD, with_table_name from dlt.common.normalizers.json.relational import JSONNormalizerConfigPropagation, _flatten, _get_child_row_hash, _normalize_row, normalize_data_item @@ -519,13 +518,13 @@ def test_complex_types_for_recursion_level(schema: Schema) -> None: "lo": [{"e": {"v": 1}}] # , {"e": {"v": 2}}, {"e":{"v":3 }} }] } - n_rows_nl = list(schema.normalize_data_item(schema, row, "load_id")) + n_rows_nl = list(schema.normalize_data_item(schema, row, "load_id", "default")) # all nested elements were yielded assert ["default", "default__f", "default__f__l", "default__f__lo"] == [r[0][0] for r in n_rows_nl] # set max nesting to 0 schema._normalizers_config["json"]["config"]["max_nesting"] = 0 - n_rows = list(schema.normalize_data_item(schema, row, "load_id")) + n_rows = list(schema.normalize_data_item(schema, row, "load_id", "default")) # the "f" element is left as complex type and not normalized assert len(n_rows) == 1 assert n_rows[0][0][0] == "default" @@ -534,7 +533,7 @@ def test_complex_types_for_recursion_level(schema: Schema) -> None: # max nesting 1 schema._normalizers_config["json"]["config"]["max_nesting"] = 1 - n_rows = list(schema.normalize_data_item(schema, row, "load_id")) + n_rows = list(schema.normalize_data_item(schema, row, "load_id", "default")) assert len(n_rows) == 2 assert ["default", "default__f"] == [r[0][0] for r in n_rows] # on level f, "l" and "lo" are not normalized @@ -545,7 +544,7 @@ def test_complex_types_for_recursion_level(schema: Schema) -> None: # max nesting 2 schema._normalizers_config["json"]["config"]["max_nesting"] = 2 - n_rows = list(schema.normalize_data_item(schema, row, "load_id")) + n_rows = list(schema.normalize_data_item(schema, row, "load_id", "default")) assert len(n_rows) == 4 # in default__f__lo the dicts that would be flattened are complex types last_row = n_rows[3] @@ -553,7 +552,7 @@ def test_complex_types_for_recursion_level(schema: Schema) -> None: # max nesting 3 schema._normalizers_config["json"]["config"]["max_nesting"] = 3 - n_rows = list(schema.normalize_data_item(schema, row, "load_id")) + n_rows = list(schema.normalize_data_item(schema, row, "load_id", "default")) assert n_rows_nl == n_rows @@ -570,13 +569,11 @@ def test_extract_with_table_name_meta() -> None: } # force table name rows = list( - 
normalize_data_item(create_schema_with_name("discord"), with_table_name(row, "channel"), "load_id") + normalize_data_item(create_schema_with_name("discord"), row, "load_id", "channel") ) # table is channel assert rows[0][0][0] == "channel" normalized_row = rows[0][1] - # _dlt_meta must be removed must be removed - assert DLT_METADATA_FIELD not in normalized_row assert normalized_row["guild_id"] == "815421435900198962" assert "_dlt_id" in normalized_row assert normalized_row["_dlt_load_id"] == "load_id" @@ -588,7 +585,7 @@ def test_table_name_meta_normalized() -> None: } # force table name rows = list( - normalize_data_item(create_schema_with_name("discord"), with_table_name(row, "channelSURFING"), "load_id") + normalize_data_item(create_schema_with_name("discord"), row, "load_id", "channelSURFING") ) # table is channel assert rows[0][0][0] == "channel_surfing" @@ -607,7 +604,7 @@ def test_parse_with_primary_key() -> None: "wo_id": [1, 2, 3] }] } - rows = list(normalize_data_item(schema, row, "load_id")) + rows = list(normalize_data_item(schema, row, "load_id", "discord")) # get root root = next(t[1] for t in rows if t[0][0] == "discord") assert root["_dlt_id"] == digest128("817949077341208606") @@ -633,7 +630,7 @@ def test_parse_with_primary_key() -> None: def test_keeps_none_values() -> None: row = {"a": None, "timestamp": 7} - rows = list(normalize_data_item(create_schema_with_name("other"), row, "1762162.1212")) + rows = list(normalize_data_item(create_schema_with_name("other"), row, "1762162.1212", "other")) table_name = rows[0][0][0] assert table_name == "other" normalized_row = rows[0][1] diff --git a/tests/common/schema/custom_normalizers.py b/tests/common/schema/custom_normalizers.py index 5709dd8db2..a69b22df1f 100644 --- a/tests/common/schema/custom_normalizers.py +++ b/tests/common/schema/custom_normalizers.py @@ -22,5 +22,5 @@ def extend_schema(schema: Schema) -> None: d_h["not_null"] = json_config["not_null"] -def normalize_data_item(schema: Schema, source_event: TDataItem, load_id: str) -> TNormalizedRowIterator: - yield ("table", None), source_event +def normalize_data_item(schema: Schema, source_event: TDataItem, load_id: str, table_name) -> TNormalizedRowIterator: + yield (table_name, None), source_event From b07fb5466df827ba9521161fb06d7c843369c357 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 24 Sep 2022 17:49:47 +0200 Subject: [PATCH 13/66] adds performance improvements to Schema, 3x normalization speed gained by just removing runtime protocol check --- dlt/common/schema/schema.py | 36 ++++++++++++-------------- dlt/common/schema/utils.py | 37 ++++++++++++++++++--------- tests/common/schema/test_filtering.py | 5 ++-- tests/common/schema/test_inference.py | 6 ++--- tests/common/schema/test_schema.py | 6 ++--- 5 files changed, 50 insertions(+), 40 deletions(-) diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index e69a717b65..31b20efb69 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -80,7 +80,7 @@ def from_dict(cls, d: DictStrAny) -> "Schema": # create new instance from dict self: Schema = cls(stored_schema["name"], normalizers=stored_schema.get("normalizers", None)) - self._schema_tables = stored_schema.get("tables", {}) + self._schema_tables = stored_schema.get("tables") or {} if Schema.VERSION_TABLE_NAME not in self._schema_tables: raise SchemaCorruptedException(f"Schema must contain table {Schema.VERSION_TABLE_NAME}") if Schema.LOADS_TABLE_NAME not in self._schema_tables: @@ -88,7 +88,7 @@ def 
from_dict(cls, d: DictStrAny) -> "Schema": self._stored_version = stored_schema["version"] self._stored_version_hash = stored_schema["version_hash"] self._imported_version_hash = stored_schema.get("imported_version_hash") - self._settings = stored_schema.get("settings", {}) + self._settings = stored_schema.get("settings") or {} # compile regexes self._compile_regexes() @@ -139,7 +139,7 @@ def _exclude(path: str, excludes: Sequence[REPattern], includes: Sequence[REPatt excludes = self._compiled_excludes.get(c_t) # only if there's possibility to exclude, continue if excludes: - includes = self._compiled_includes.get(c_t, []) + includes = self._compiled_includes.get(c_t) or [] for field_name in list(row.keys()): path = self.normalize_make_path(*branch[i:], field_name) if _exclude(path, excludes, includes): @@ -150,12 +150,14 @@ def _exclude(path: str, excludes: Sequence[REPattern], includes: Sequence[REPatt break return row - def coerce_row(self, table_name: str, parent_table: str, row: StrAny) -> Tuple[StrAny, TPartialTableSchema]: + def coerce_row(self, table_name: str, parent_table: str, row: StrAny) -> Tuple[DictStrAny, TPartialTableSchema]: # get existing or create a new table - table = self._schema_tables.get(table_name, utils.new_table(table_name, parent_table)) + updated_table_partial: TPartialTableSchema = None + table = self._schema_tables.get(table_name) + if not table: + table = utils.new_table(table_name, parent_table) table_columns = table["columns"] - partial_table: TPartialTableSchema = None new_row: DictStrAny = {} for col_name, v in row.items(): # skip None values, we should infer the types later @@ -166,12 +168,13 @@ def coerce_row(self, table_name: str, parent_table: str, row: StrAny) -> Tuple[S new_col_name, new_col_def, new_v = self._coerce_non_null_value(table_columns, table_name, col_name, v) new_row[new_col_name] = new_v if new_col_def: - if not partial_table: - partial_table = copy(table) - partial_table["columns"] = {} - partial_table["columns"][new_col_name] = new_col_def + if not updated_table_partial: + # create partial table with only the new columns + updated_table_partial = copy(table) + updated_table_partial["columns"] = {} + updated_table_partial["columns"][new_col_name] = new_col_def - return new_row, partial_table + return new_row, updated_table_partial def update_schema(self, partial_table: TPartialTableSchema) -> None: table_name = partial_table["name"] @@ -188,12 +191,6 @@ def update_schema(self, partial_table: TPartialTableSchema) -> None: # add the whole new table to SchemaTables self._schema_tables[table_name] = partial_table else: - # partial_w_d = partial_table.get("write_disposition") - # if table.get("parent") and not table.get("write_disposition") and partial_w_d: - # # get write disposition recursively for existing table and check if those fit - # existing_w_d = self.get_write_disposition(table_name) - # if existing_w_d != partial_table.get("write_disposition"): - # raise TablePropertiesConflictException(table_name, "write_disposition", existing_w_d, partial_w_d) # merge tables performing additional checks utils.merge_tables(table, partial_table) @@ -356,7 +353,7 @@ def _coerce_non_null_value(self, table_columns: TTableSchemaColumns, table_name: # infer type or get it from existing table col_type = existing_column.get("data_type") if existing_column else self._infer_column_type(v, col_name) - # get real python type + # get data type of value py_type = utils.py_type_to_sc_type(type(v)) # and coerce type if inference changed the python type try: 
@@ -373,7 +370,8 @@ def _coerce_non_null_value(self, table_columns: TTableSchemaColumns, table_name: return self._coerce_non_null_value(table_columns, table_name, variant_col_name, v, final=True) # if coerced value is variant, then extract variant value - if isinstance(coerced_v, SupportsVariant): + # note: checking runtime protocols with isinstance(coerced_v, SupportsVariant): is extremely slow so we check if callable as every variant is callable + if callable(coerced_v) and isinstance(coerced_v, SupportsVariant): coerced_v = coerced_v() if isinstance(coerced_v, tuple): # variant recovered so call recursively with variant column name and variant value diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index 08305ed7ce..ec35c3bb31 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -82,13 +82,13 @@ def generate_version_hash(stored_schema: TStoredSchema) -> str: content = json.dumps(schema_copy, sort_keys=True) h = hashlib.sha3_256(content.encode("utf-8")) # additionally check column order - table_names = sorted(schema_copy.get("tables", {}).keys()) + table_names = sorted((schema_copy.get("tables") or {}).keys()) if table_names: for tn in table_names: t = schema_copy["tables"][tn] h.update(tn.encode("utf-8")) # add column names to hash in order - for cn in t.get("columns", {}).keys(): + for cn in (t.get("columns") or {}).keys(): h.update(cn.encode("utf-8")) return base64.b64encode(h.digest()).decode('ascii') @@ -268,26 +268,39 @@ def autodetect_sc_type(detection_fs: Sequence[TTypeDetections], t: Type[Any], v: def py_type_to_sc_type(t: Type[Any]) -> TDataType: - if issubclass(t, float): + # start with most popular types + if t is str: + return "text" + if t is float: return "double" # bool is subclass of int so must go first - elif t is bool: + if t is bool: return "bool" - elif issubclass(t, int): + if t is int: return "bigint" - elif issubclass(t, bytes): - return "binary" - elif issubclass(t, (dict, list)): + if issubclass(t, (dict, list)): return "complex" + + # those are special types that will not be present in json loaded dict # wei is subclass of decimal and must be checked first - elif issubclass(t, Wei): + if issubclass(t, Wei): return "wei" - elif issubclass(t, Decimal): + if issubclass(t, Decimal): return "decimal" - elif issubclass(t, datetime.datetime): + if issubclass(t, datetime.datetime): return "timestamp" - else: + + # check again for subclassed basic types + if issubclass(t, str): return "text" + if issubclass(t, float): + return "double" + if issubclass(t, int): + return "bigint" + if issubclass(t, bytes): + return "binary" + + return "text" def coerce_type(to_type: TDataType, from_type: TDataType, value: Any) -> Any: diff --git a/tests/common/schema/test_filtering.py b/tests/common/schema/test_filtering.py index ddec6ac34e..03a2e6de4a 100644 --- a/tests/common/schema/test_filtering.py +++ b/tests/common/schema/test_filtering.py @@ -1,7 +1,6 @@ import pytest from copy import deepcopy from dlt.common.schema.exceptions import ParentTableNotFoundException -from dlt.common.sources import with_table_name from dlt.common.typing import StrAny from dlt.common.schema import Schema @@ -73,7 +72,7 @@ def test_filter_parent_table_schema_update(schema: Schema) -> None: updates = [] - for (t, p), row in schema.normalize_data_item(schema, with_table_name(source_row, "event_bot"), "load_id"): + for (t, p), row in schema.normalize_data_item(schema, source_row, "load_id", "event_bot"): row = schema.filter_row(t, row) if not row: # those rows are 
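The comment above, that runtime protocol checks are very slow compared to callable(), can be sanity-checked with a micro-benchmark; the protocol below is a stand-in rather than the actual SupportsVariant, and the absolute numbers will vary by interpreter and machine:

    import timeit
    from typing import Any, Protocol, runtime_checkable

    @runtime_checkable
    class SupportsVariantLike(Protocol):
        def __call__(self) -> Any: ...

    value = 1.5  # a plain coerced value, the common case during normalization

    protocol_t = timeit.timeit(lambda: isinstance(value, SupportsVariantLike), number=100_000)
    callable_t = timeit.timeit(lambda: callable(value), number=100_000)
    print(f"isinstance(Protocol): {protocol_t:.3f}s  callable(): {callable_t:.3f}s")
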
fully removed @@ -98,7 +97,7 @@ def test_filter_parent_table_schema_update(schema: Schema) -> None: _add_excludes(schema) schema.get_table("event_bot")["filters"]["includes"].extend(["re:^metadata___dlt_", "re:^metadata__elvl1___dlt_"]) schema._compile_regexes() - for (t, p), row in schema.normalize_data_item(schema, with_table_name(source_row, "event_bot"), "load_id"): + for (t, p), row in schema.normalize_data_item(schema, source_row, "load_id", "event_bot"): row = schema.filter_row(t, row) if p is None: assert "_dlt_id" in row diff --git a/tests/common/schema/test_inference.py b/tests/common/schema/test_inference.py index f53b99f7b4..efb2650ce8 100644 --- a/tests/common/schema/test_inference.py +++ b/tests/common/schema/test_inference.py @@ -208,7 +208,7 @@ def test_coerce_complex_variant(schema: Schema) -> None: def test_supports_variant_pua_decode(schema: Schema) -> None: rows = load_json_case("pua_encoded_row") - normalized_row = list(schema.normalize_data_item(schema, rows[0], "0912uhj222")) + normalized_row = list(schema.normalize_data_item(schema, rows[0], "0912uhj222", "event")) # pua encoding still present assert normalized_row[0][1]["wad"].startswith("") # decode pua @@ -223,7 +223,7 @@ def test_supports_variant(schema: Schema) -> None: rows = [{"evm": Wei.from_int256(2137*10**16, decimals=18)}, {"evm": Wei.from_int256(2**256-1)}] normalized_rows = [] for row in rows: - normalized_rows.extend(schema.normalize_data_item(schema, row, "128812.2131")) + normalized_rows.extend(schema.normalize_data_item(schema, row, "128812.2131", "event")) # row 1 contains Wei assert isinstance(normalized_rows[0][1]["evm"], Wei) assert normalized_rows[0][1]["evm"] == Wei("21.37") @@ -281,7 +281,7 @@ def __call__(self) -> Any: rows = [{"pv": PureVariant(3377)}, {"pv": PureVariant(21.37)}] normalized_rows = [] for row in rows: - normalized_rows.extend(schema.normalize_data_item(schema, row, "128812.2131")) + normalized_rows.extend(schema.normalize_data_item(schema, row, "128812.2131", "event")) assert normalized_rows[0][1]["pv"]() == 3377 assert normalized_rows[1][1]["pv"]() == ("text", 21.37) # first normalized row fits into schema (pv is int) diff --git a/tests/common/schema/test_schema.py b/tests/common/schema/test_schema.py index d664b14cf4..4ffb0b2399 100644 --- a/tests/common/schema/test_schema.py +++ b/tests/common/schema/test_schema.py @@ -394,8 +394,8 @@ def assert_new_schema_values_custom_normalizers(schema: Schema) -> None: # assumes elements are normalized assert schema.normalize_make_path("A", "B", "!C") == "A__B__!C" assert schema.normalize_break_path("A__B__!C") == ["A", "B", "!C"] - row = list(schema.normalize_data_item(schema, {"bool": True}, "load_id")) - assert row[0] == (("table", None), {"bool": True}) + row = list(schema.normalize_data_item(schema, {"bool": True}, "load_id", "a_table")) + assert row[0] == (("a_table", None), {"bool": True}) def assert_new_schema_values(schema: Schema) -> None: @@ -417,7 +417,7 @@ def assert_new_schema_values(schema: Schema) -> None: # assumes elements are normalized assert schema.normalize_make_path("A", "B", "!C") == "A__B__!C" assert schema.normalize_break_path("A__B__!C") == ["A", "B", "!C"] - schema.normalize_data_item(schema, {}, "load_id") + schema.normalize_data_item(schema, {}, "load_id", schema.name) # check default tables tables = schema.tables assert "_dlt_version" in tables From d8d8563ac79702da7333c7a40b2a965cb91d97f7 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 24 Sep 2022 17:51:40 +0200 Subject: [PATCH 14/66] removes 
runtime protocol check from schema altogether --- dlt/common/schema/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index 31b20efb69..ba1a1af317 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -371,7 +371,7 @@ def _coerce_non_null_value(self, table_columns: TTableSchemaColumns, table_name: # if coerced value is variant, then extract variant value # note: checking runtime protocols with isinstance(coerced_v, SupportsVariant): is extremely slow so we check if callable as every variant is callable - if callable(coerced_v) and isinstance(coerced_v, SupportsVariant): + if callable(coerced_v): # and isinstance(coerced_v, SupportsVariant): coerced_v = coerced_v() if isinstance(coerced_v, tuple): # variant recovered so call recursively with variant column name and variant value From f98ddbbfe13bc3a5d6761970475bc3ab88e84590 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 24 Sep 2022 17:53:25 +0200 Subject: [PATCH 15/66] implements data item storage and adds to normalize and load storages as mixins --- dlt/common/storages/__init__.py | 1 + dlt/common/storages/data_item_storage.py | 38 +++++++++++++++++++++ dlt/common/storages/load_storage.py | 34 ++++++++++-------- dlt/common/storages/normalize_storage.py | 26 +------------- tests/load/bigquery/test_bigquery_client.py | 2 +- tests/tools/create_storages.py | 2 +- 6 files changed, 62 insertions(+), 41 deletions(-) create mode 100644 dlt/common/storages/data_item_storage.py diff --git a/dlt/common/storages/__init__.py b/dlt/common/storages/__init__.py index 8cf7b71216..9cae20f688 100644 --- a/dlt/common/storages/__init__.py +++ b/dlt/common/storages/__init__.py @@ -3,3 +3,4 @@ from .normalize_storage import NormalizeStorage # noqa: F401 from .versioned_storage import VersionedStorage # noqa: F401 from .load_storage import LoadStorage # noqa: F401 +from .data_item_storage import DataItemStorage # noqa: F401 diff --git a/dlt/common/storages/data_item_storage.py b/dlt/common/storages/data_item_storage.py new file mode 100644 index 0000000000..9c4e05249c --- /dev/null +++ b/dlt/common/storages/data_item_storage.py @@ -0,0 +1,38 @@ +from typing import Dict, Any +from abc import ABC, abstractmethod + +from dlt.common import logger +from dlt.common.schema import TTableSchemaColumns +from dlt.common.sources import TDirectDataItem +from dlt.common.data_writers import TLoaderFileFormat, BufferedDataWriter + + +class DataItemStorage(ABC): + def __init__(self, load_file_type: TLoaderFileFormat, *args: Any) -> None: + self.loader_file_format = load_file_type + self.buffered_writers: Dict[str, BufferedDataWriter] = {} + super().__init__(*args) + + def write_data_item(self, load_id: str, schema_name: str, table_name: str, item: TDirectDataItem, columns: TTableSchemaColumns) -> None: + # unique writer id + writer_id = f"{load_id}.{schema_name}.{table_name}" + writer = self.buffered_writers.get(writer_id, None) + if not writer: + # assign a jsonl writer for each table + path = self._get_data_item_path_template(load_id, schema_name, table_name) + writer = BufferedDataWriter(self.loader_file_format, path) + self.buffered_writers[writer_id] = writer + # write item(s) + writer.write_data_item(item, columns) + + def close_writers(self, extract_id: str) -> None: + # flush and close all files + for name, writer in self.buffered_writers.items(): + if name.startswith(extract_id): + logger.debug(f"Closing writer for {name} with file {writer._file} and 
actual name {writer._file_name}") + writer.close_writer() + + @abstractmethod + def _get_data_item_path_template(self, load_id: str, schema_name: str, table_name: str) -> str: + # note: use %s for file id to create required template format + pass diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index 456ae9d538..c8b139800c 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -4,14 +4,14 @@ from typing import Iterable, NamedTuple, Literal, Optional, Sequence, Set, Tuple, Type, get_args from dlt.common import json, pendulum +from dlt.common.typing import DictStrAny, StrAny from dlt.common.file_storage import FileStorage from dlt.common.data_writers import TLoaderFileFormat, DataWriter from dlt.common.configuration import LoadVolumeConfiguration from dlt.common.exceptions import TerminalValueError from dlt.common.schema import Schema, TSchemaUpdate, TTableSchemaColumns from dlt.common.storages.versioned_storage import VersionedStorage -from dlt.common.typing import DictStrAny, StrAny - +from dlt.common.storages.data_item_storage import DataItemStorage from dlt.common.storages.exceptions import JobWithUnsupportedWriterException @@ -24,7 +24,7 @@ class TParsedJobFileName(NamedTuple): file_format: TLoaderFileFormat -class LoadStorage(VersionedStorage): +class LoadStorage(DataItemStorage, VersionedStorage): STORAGE_VERSION = "1.0.0" NORMALIZED_FOLDER = "normalized" # folder within the volume where load packages are stored @@ -52,10 +52,13 @@ def __init__( raise TerminalValueError(supported_file_formats) if preferred_file_format not in supported_file_formats: raise TerminalValueError(preferred_file_format) - self.preferred_file_format = preferred_file_format self.supported_file_formats = supported_file_formats self.delete_completed_jobs = C.DELETE_COMPLETED_JOBS - super().__init__(LoadStorage.STORAGE_VERSION, is_owner, FileStorage(C.LOAD_VOLUME_PATH, "t", makedirs=is_owner)) + super().__init__( + preferred_file_format, + LoadStorage.STORAGE_VERSION, + is_owner, FileStorage(C.LOAD_VOLUME_PATH, "t", makedirs=is_owner) + ) if is_owner: self.initialize_storage() @@ -74,15 +77,15 @@ def create_temp_load_package(self, load_id: str) -> None: self.storage.create_folder(join(load_id, LoadStorage.FAILED_JOBS_FOLDER)) self.storage.create_folder(join(load_id, LoadStorage.STARTED_JOBS_FOLDER)) + def _get_data_item_path_template(self, load_id: str, _: str, table_name: str) -> str: + file_name = self.build_job_file_name(table_name, "%s", with_extension=False) + return self.storage.make_full_path(join(load_id, LoadStorage.NEW_JOBS_FOLDER, file_name)) + def write_temp_job_file(self, load_id: str, table_name: str, table: TTableSchemaColumns, file_id: str, rows: Sequence[StrAny]) -> str: - file_name = self.build_job_file_name(table_name, file_id) - with self.storage.open_file(join(load_id, LoadStorage.NEW_JOBS_FOLDER, file_name), mode="w") as f: - writer = DataWriter.from_file_format(self.preferred_file_format, f) + file_name = self._get_data_item_path_template(load_id, None, table_name) % file_id + "." 
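DataItemStorage is written as a mixin: a concrete storage only needs to supply the %s-based path template and gets buffering, rotation and writer bookkeeping from BufferedDataWriter. A minimal stand-alone subclass, purely for illustration (LoadStorage below is the real user of this contract):

    import os
    import tempfile

    from dlt.common.storages import DataItemStorage

    class TempDirItemStorage(DataItemStorage):
        # hypothetical subclass, only to show the abstract contract
        def __init__(self, root: str) -> None:
            self.root = root
            super().__init__("jsonl")

        def _get_data_item_path_template(self, load_id: str, schema_name: str, table_name: str) -> str:
            # %s is later substituted with a unique file id by BufferedDataWriter
            return os.path.join(self.root, f"{load_id}.{schema_name}.{table_name}.%s")

    storage = TempDirItemStorage(tempfile.mkdtemp())
    storage.write_data_item("load_1", "event", "user", {"id": 1}, columns={})
    storage.close_writers("load_1")
    print(storage.buffered_writers["load_1.event.user"].all_files)
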
+ self.loader_file_format + with self.storage.open_file(file_name, mode="w") as f: + writer = DataWriter.from_file_format(self.loader_file_format, f) writer.write_all(table, rows) - # if self.preferred_file_format == "jsonl": - # write_jsonl(f, rows) - # elif self.preferred_file_format == "insert_values": - # write_insert_values(f, rows, table.keys()) return Path(file_name).name def load_package_schema(self, load_id: str) -> Schema: @@ -214,11 +217,14 @@ def _get_job_folder_path(self, load_id: str, folder: TWorkingFolder) -> str: def _get_job_file_path(self, load_id: str, folder: TWorkingFolder, file_name: str) -> str: return join(self._get_job_folder_path(load_id, folder), file_name) - def build_job_file_name(self, table_name: str, file_id: str, retry_count: int = 0, validate_components: bool = True) -> str: + def build_job_file_name(self, table_name: str, file_id: str, retry_count: int = 0, validate_components: bool = True, with_extension: bool = True) -> str: if validate_components: FileStorage.validate_file_name_component(table_name) FileStorage.validate_file_name_component(file_id) - return f"{table_name}.{file_id}.{int(retry_count)}.{self.preferred_file_format}" + fn = f"{table_name}.{file_id}.{int(retry_count)}" + if with_extension: + return fn + f".{self.loader_file_format}" + return fn @staticmethod def parse_job_file_name(file_name: str) -> TParsedJobFileName: diff --git a/dlt/common/storages/normalize_storage.py b/dlt/common/storages/normalize_storage.py index 3cdbf4a1c5..485bc480a1 100644 --- a/dlt/common/storages/normalize_storage.py +++ b/dlt/common/storages/normalize_storage.py @@ -21,6 +21,7 @@ class NormalizeStorage(VersionedStorage): def __init__(self, is_owner: bool, C: Type[NormalizeVolumeConfiguration]) -> None: super().__init__(NormalizeStorage.STORAGE_VERSION, is_owner, FileStorage(C.NORMALIZE_VOLUME_PATH, "t", makedirs=is_owner)) + self.CONFIG = C if is_owner: self.initialize_storage() @@ -33,31 +34,6 @@ def list_files_to_normalize_sorted(self) -> Sequence[str]: def get_grouped_iterator(self, files: Sequence[str]) -> "groupby[str, str]": return groupby(files, lambda f: NormalizeStorage.get_schema_name(f)) - @staticmethod - def chunk_by_events(files: Sequence[str], max_events: int, processing_cores: int) -> List[Sequence[str]]: - return [files] - - # # should distribute ~ N events evenly among m cores with fallback for small amounts of events - - # def count_events(file_name : str) -> int: - # # return event count from file name - # return NormalizeStorage.get_events_count(file_name) - - # counts = list(map(count_events, files)) - # # make a list of files containing ~max_events - # events_count = 0 - # m = 0 - # while events_count < max_events and m < len(files): - # events_count += counts[m] - # m += 1 - # processing_chunks = round(m / processing_cores) - # if processing_chunks == 0: - # # return one small chunk - # return [files] - # else: - # # should return ~ amount of chunks to fill all the cores - # return list(chunks(files[:m], processing_chunks)) - @staticmethod def get_schema_name(file_name: str) -> str: return NormalizeStorage.parse_normalize_file_name(file_name).schema_name diff --git a/tests/load/bigquery/test_bigquery_client.py b/tests/load/bigquery/test_bigquery_client.py index e70b0bb9dc..3d33eba9ee 100644 --- a/tests/load/bigquery/test_bigquery_client.py +++ b/tests/load/bigquery/test_bigquery_client.py @@ -78,7 +78,7 @@ def test_bigquery_location(location: str, file_storage: FileStorage) -> None: job = expect_load_file(client, file_storage, 
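build_job_file_name above encodes table name, file id, retry count and file format into the job file name, so parse_job_file_name can recover them later without opening the file. Roughly (the exact NamedTuple repr is abbreviated in the comment):

    from dlt.common.storages import LoadStorage

    parsed = LoadStorage.parse_job_file_name("event_user.a1b2c.0.jsonl")
    # -> table_name "event_user", file_id "a1b2c", retry_count 0, file_format "jsonl"
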
json.dumps(load_json), user_table_name) # start a job from the same file. it should fallback to retrieve job silently - client.start_file_load(client.schema.get_table(user_table_name), file_storage._make_path(job.file_name())) + client.start_file_load(client.schema.get_table(user_table_name), file_storage.make_full_path(job.file_name())) canonical_name = client.sql_client.make_qualified_table_name(user_table_name) t = client.sql_client.native_connection.get_table(canonical_name) assert t.location == location diff --git a/tests/tools/create_storages.py b/tests/tools/create_storages.py index f2ae8d2260..680f3dd61f 100644 --- a/tests/tools/create_storages.py +++ b/tests/tools/create_storages.py @@ -4,4 +4,4 @@ NormalizeStorage(True, NormalizeVolumeConfiguration) LoadStorage(True, LoadVolumeConfiguration, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) -SchemaStorage(SchemaVolumeConfiguration.SCHEMA_VOLUME_PATH, makedirs=True) +SchemaStorage(SchemaVolumeConfiguration, makedirs=True) From 4d3861b1e4adea665bd59b95fd6301793513a529 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 24 Sep 2022 17:54:41 +0200 Subject: [PATCH 16/66] rewrites Normalize to work with new extracted jsonl files and use buffered data writers --- dlt/normalize/configuration.py | 1 - dlt/normalize/normalize.py | 184 +++++++++--------- ...s.9c1d9b504ea240a482b007788d5cd61c_2.json} | 0 ...t.bot_load_metadata_2987398237498798.json} | 0 ...cted.json => event.event.many_load_2.json} | 0 ... event.event.slot_session_metadata_1.json} | 0 ...cted.json => event.event.user_load_1.json} | 0 ...json => event.event.user_load_v228_1.json} | 0 tests/normalize/mock_rasa_json_normalizer.py | 15 +- tests/normalize/test_normalize.py | 98 +++++----- 10 files changed, 157 insertions(+), 141 deletions(-) rename tests/normalize/cases/{ethereum_blocks_9c1d9b504ea240a482b007788d5cd61c_2.extracted.json => ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2.json} (100%) rename tests/normalize/cases/{event_bot_load_metadata_1.extracted.json => event.event.bot_load_metadata_2987398237498798.json} (100%) rename tests/normalize/cases/{event_many_load_2.extracted.json => event.event.many_load_2.json} (100%) rename tests/normalize/cases/{event_slot_session_metadata_1.extracted.json => event.event.slot_session_metadata_1.json} (100%) rename tests/normalize/cases/{event_user_load_1.extracted.json => event.event.user_load_1.json} (100%) rename tests/normalize/cases/{event_user_load_v228_1.extracted.json => event.event.user_load_v228_1.json} (100%) diff --git a/dlt/normalize/configuration.py b/dlt/normalize/configuration.py index b716f5f613..ef8bda141d 100644 --- a/dlt/normalize/configuration.py +++ b/dlt/normalize/configuration.py @@ -12,7 +12,6 @@ class NormalizeConfiguration(PoolRunnerConfiguration, NormalizeVolumeConfiguration, LoadVolumeConfiguration, SchemaVolumeConfiguration): - MAX_EVENTS_IN_CHUNK: int = 40000 # maximum events to be processed in single chunk LOADER_FILE_FORMAT: TLoaderFileFormat = "jsonl" # jsonp or insert commands will be generated POOL_TYPE: TPoolType = "process" diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 3b4993bcf9..81761d881f 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -6,27 +6,32 @@ from dlt.common import pendulum, signals, json, logger from dlt.common.json import custom_pua_decode from dlt.cli import TRunnerArgs +from dlt.common.normalizers.json import wrap_in_dict from dlt.common.runners import TRunMetrics, Runnable, run_pool, initialize_runner, workermethod 
+from dlt.common.runners.runnable import configuredworker +from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.storages.exceptions import SchemaNotFoundError from dlt.common.storages import NormalizeStorage, SchemaStorage, LoadStorage from dlt.common.telemetry import get_logging_extras -from dlt.common.utils import uniq_id -from dlt.common.typing import TDataItem +from dlt.common.typing import StrAny, TDataItem from dlt.common.exceptions import PoolException from dlt.common.schema import TSchemaUpdate, Schema from dlt.common.schema.exceptions import CannotCoerceColumnException +from dlt.common.utils import uniq_id from dlt.normalize.configuration import configuration, NormalizeConfiguration -TMapFuncRV = Tuple[List[TSchemaUpdate], List[Sequence[str]]] -TMapFuncType = Callable[[str, str, Sequence[str]], TMapFuncRV] +# normalize worker wrapping function (map_parallel, map_single) return type +TMapFuncRV = Tuple[int, List[TSchemaUpdate], List[Sequence[str]]] # (total items processed, list of schema updates, list of processed files) +# normalize worker wrapping function signature +TMapFuncType = Callable[[str, str, Sequence[str]], TMapFuncRV] # input parameters: (schema name, load_id, list of files to process) class Normalize(Runnable[ProcessPool]): # make our gauges static - event_counter: Counter = None - event_gauge: Gauge = None + item_counter: Counter = None + item_gauge: Gauge = None schema_version_gauge: Gauge = None load_package_counter: Counter = None @@ -45,13 +50,13 @@ def __init__(self, C: Type[NormalizeConfiguration], collector: CollectorRegistry self.create_gauges(collector) except ValueError as v: # ignore re-creation of gauges - if "Duplicated timeseries" not in str(v): + if "Duplicated time-series" not in str(v): raise @staticmethod def create_gauges(registry: CollectorRegistry) -> None: - Normalize.event_counter = Counter("normalize_event_count", "Events processed in normalize", ["schema"], registry=registry) - Normalize.event_gauge = Gauge("normalize_last_events", "Number of events processed in last run", ["schema"], registry=registry) + Normalize.item_counter = Counter("normalize_item_count", "Items processed in normalize", ["schema"], registry=registry) + Normalize.item_gauge = Gauge("normalize_last_items", "Number of items processed in last run", ["schema"], registry=registry) Normalize.schema_version_gauge = Gauge("normalize_schema_version", "Current schema version", ["schema"], registry=registry) Normalize.load_package_counter = Gauge("normalize_load_packages_created_count", "Count of load package created", ["schema"], registry=registry) @@ -60,9 +65,11 @@ def create_storages(self) -> None: # normalize saves in preferred format but can read all supported formats self.load_storage = LoadStorage(True, self.CONFIG, self.CONFIG.LOADER_FILE_FORMAT, LoadStorage.ALL_SUPPORTED_FILE_FORMATS) - def load_or_create_schema(self, schema_name: str) -> Schema: + + @staticmethod + def load_or_create_schema(schema_storage: SchemaStorage, schema_name: str) -> Schema: try: - schema = self.schema_storage.load_schema(schema_name) + schema = schema_storage.load_schema(schema_name) logger.info(f"Loaded schema with name {schema_name} with version {schema.stored_version}") except SchemaNotFoundError: schema = Schema(schema_name) @@ -70,74 +77,79 @@ def load_or_create_schema(self, schema_name: str) -> Schema: return schema @staticmethod - @workermethod - def w_normalize_files(self: "Normalize", schema_name: str, load_id: str, events_files: Sequence[str]) -> TSchemaUpdate: - 
normalized_data: Dict[str, List[Any]] = {} + @configuredworker + def w_normalize_files(CONFIG: Type[NormalizeConfiguration], schema_name: str, load_id: str, extracted_items_files: Sequence[str]) -> Tuple[TSchemaUpdate, int]: + schema = Normalize.load_or_create_schema(SchemaStorage(CONFIG, makedirs=False), schema_name) + load_storage = LoadStorage(False, CONFIG, CONFIG.LOADER_FILE_FORMAT, LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + normalize_storage = NormalizeStorage(False, CONFIG) schema_update: TSchemaUpdate = {} - schema = self.load_or_create_schema(schema_name) - file_id = uniq_id(5) - - # process all event files and store rows in memory - for events_file in events_files: - i: int = 0 - event: TDataItem = None - try: - logger.debug(f"Processing events file {events_file} in load_id {load_id} with file_id {file_id}") - with self.normalize_storage.storage.open_file(events_file) as f: - events: Sequence[TDataItem] = json.load(f) - for i, event in enumerate(events): - for (table_name, parent_table), row in schema.normalize_data_item(schema, event, load_id): - # filter row, may eliminate some or all fields - row = schema.filter_row(table_name, row) - # do not process empty rows - if row: - # decode pua types - for k, v in row.items(): - row[k] = custom_pua_decode(v) # type: ignore - # check if schema can be updated - row, partial_table = schema.coerce_row(table_name, parent_table, row) - if partial_table: - # update schema and save the change - schema.update_schema(partial_table) - table_updates = schema_update.setdefault(table_name, []) - table_updates.append(partial_table) - # store row - rows = normalized_data.setdefault(table_name, []) - rows.append(row) - if i % 100 == 0: - logger.debug(f"Processed {i} of {len(events)} events") - except Exception: - logger.exception(f"Exception when processing file {events_file}, event idx {i}") - logger.debug(f"Affected event: {event}") - raise PoolException("normalize_files", events_file) - - # save rows and return schema changes to be gathered in parent process - for table_name, rows in normalized_data.items(): - # save into new jobs to processed as load - table = schema.get_table_columns(table_name) - self.load_storage.write_temp_job_file(load_id, table_name, table, file_id, rows) - - return schema_update + column_schemas: Dict[str, TTableSchemaColumns] = {} # quick access to column schema for writers below + total_items = 0 + + # process all files with data items and write to buffered item storage + try: + for extracted_items_file in extracted_items_files: + line_no: int = 0 + item: TDataItem = None + parent_table_name = NormalizeStorage.parse_normalize_file_name(extracted_items_file).table_name + logger.debug(f"Processing extracted items in {extracted_items_file} in load_id {load_id} with table name {parent_table_name} and schema {schema_name}") + with normalize_storage.storage.open_file(extracted_items_file) as f: + # enumerate jsonl file line by line + for line_no, line in enumerate(f): + item = json.loads(line) + if not isinstance(item, dict): + item = wrap_in_dict(item) + for (table_name, parent_table), row in schema.normalize_data_item(schema, item, load_id, parent_table_name): + # filter row, may eliminate some or all fields + row = schema.filter_row(table_name, row) + # do not process empty rows + if row: + # decode pua types + for k, v in row.items(): + row[k] = custom_pua_decode(v) # type: ignore + # coerce row of values into schema table, generating partial table with new columns if any + row, partial_table = schema.coerce_row(table_name, 
parent_table, row) + if partial_table: + # update schema and save the change + schema.update_schema(partial_table) + table_updates = schema_update.setdefault(table_name, []) + table_updates.append(partial_table) + # get current columns schema + columns = column_schemas.get(table_name) + if not columns: + columns = schema.get_table_columns(table_name) + column_schemas[table_name] = columns + # store row + load_storage.write_data_item(load_id, schema_name, table_name, row, columns) + # count total items + total_items += 1 + if line_no > 0 and line_no % 100 == 0: + logger.debug(f"Processed {line_no} items from file {extracted_items_file}, total items {total_items}") + # if any item found in the file + if item: + logger.debug(f"Processed total {line_no + 1} lines from file {extracted_items_file}, total items {total_items}") + except Exception: + logger.exception(f"Exception when processing file {extracted_items_file}, line {line_no}") + # logger.debug(f"Affected item: {item}") + raise PoolException("normalize_files", extracted_items_file) + finally: + load_storage.close_writers(load_id) + + logger.debug(f"Processed total {total_items} items in {len(extracted_items_files)} files") + + return schema_update, total_items def map_parallel(self, schema_name: str, load_id: str, files: Sequence[str]) -> TMapFuncRV: - # we chunk files in a way to not exceed MAX_EVENTS_IN_CHUNK and split them equally - # between processors - configured_processes = self.pool._processes # type: ignore - chunk_files = NormalizeStorage.chunk_by_events(files, self.CONFIG.MAX_EVENTS_IN_CHUNK, configured_processes) - logger.info(f"Obtained {len(chunk_files)} processing chunks") - # use id of self to pass the self instance. see `Runnable` class docstrings - param_chunk = [(id(self), schema_name, load_id, files) for files in chunk_files] - return self.pool.starmap(Normalize.w_normalize_files, param_chunk), chunk_files + # TODO: maybe we should chunk by file size, now map all files to workers + chunk_files = [files] + param_chunk = [(self.CONFIG.as_dict(), schema_name, load_id, files) for files in chunk_files] + processed_chunks = self.pool.starmap(Normalize.w_normalize_files, param_chunk) + return sum([t[1] for t in processed_chunks]), [t[0] for t in processed_chunks], chunk_files def map_single(self, schema_name: str, load_id: str, files: Sequence[str]) -> TMapFuncRV: - chunk_files = NormalizeStorage.chunk_by_events(files, self.CONFIG.MAX_EVENTS_IN_CHUNK, 1) - # get in one chunk - assert len(chunk_files) == 1 - logger.info(f"Obtained {len(chunk_files)} processing chunks") - # use id of self to pass the self instance. 
see `Runnable` class docstrings - self_id: Any = id(self) - return [Normalize.w_normalize_files(self_id, schema_name, load_id, chunk_files[0])], chunk_files + processed_chunk = Normalize.w_normalize_files(self.CONFIG, schema_name, load_id, files) + return processed_chunk[1], [processed_chunk[0]], [files] def update_schema(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> int: updates_count = 0 @@ -151,9 +163,9 @@ def update_schema(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> def spool_files(self, schema_name: str, load_id: str, map_f: TMapFuncType, files: Sequence[str]) -> None: # process files in parallel or in single thread, depending on map_f - schema_updates, chunk_files = map_f(schema_name, load_id, files) + total_items, schema_updates, chunk_files = map_f(schema_name, load_id, files) - schema = self.load_or_create_schema(schema_name) + schema = Normalize.load_or_create_schema(self.schema_storage, schema_name) # gather schema from all manifests, validate consistency and combine updates_count = self.update_schema(schema, schema_updates) self.schema_version_gauge.labels(schema_name).set(schema.version) @@ -171,19 +183,16 @@ def spool_files(self, schema_name: str, load_id: str, map_f: TMapFuncType, files logger.info("Committing storage, do not kill this process") # rename temp folder to processing self.load_storage.commit_temp_load_package(load_id) - # delete event files and count events to provide metrics - total_events = 0 - for event_file in chain.from_iterable(chunk_files): # flatten chunks - self.normalize_storage.storage.delete(event_file) - # TODO: get total events from worker function and make stats per table - # total_events += .... + # delete item files to complete commit + for item_file in chain.from_iterable(chunk_files): # flatten chunks + self.normalize_storage.storage.delete(item_file) # log and update metrics logger.info(f"Chunk {load_id} processed") self.load_package_counter.labels(schema_name).inc() - self.event_counter.labels(schema_name).inc(total_events) - self.event_gauge.labels(schema_name).set(total_events) + self.item_counter.labels(schema_name).inc(total_items) + self.item_gauge.labels(schema_name).set(total_items) logger.metrics("Normalize metrics", extra=get_logging_extras( - [self.load_package_counter.labels(schema_name), self.event_counter.labels(schema_name), self.event_gauge.labels(schema_name)])) + [self.load_package_counter.labels(schema_name), self.item_counter.labels(schema_name), self.item_gauge.labels(schema_name)])) def spool_schema_files(self, schema_name: str, files: Sequence[str]) -> str: # normalized files will go here before being atomically renamed @@ -191,9 +200,11 @@ def spool_schema_files(self, schema_name: str, files: Sequence[str]) -> str: self.load_storage.create_temp_load_package(load_id) logger.info(f"Created temp load folder {load_id} on loading volume") + # if pool is not present use map_single method to run normalization in single process + map_parallel_f = self.map_parallel if self.pool else self.map_single try: # process parallel - self.spool_files(schema_name, load_id, self.map_parallel, files) + self.spool_files(schema_name, load_id, map_parallel_f, files) except CannotCoerceColumnException as exc: # schema conflicts resulting from parallel executing logger.warning(f"Parallel schema update conflict, switching to single thread ({str(exc)}") @@ -206,11 +217,10 @@ def spool_schema_files(self, schema_name: str, files: Sequence[str]) -> str: def run(self, pool: ProcessPool) -> TRunMetrics: # keep the pool 
in class instance self.pool = pool - logger.info("Running file normalizing") # list files and group by schema name, list must be sorted for group by to actually work files = self.normalize_storage.list_files_to_normalize_sorted() - logger.info(f"Found {len(files)} files, will process in chunks of {self.CONFIG.MAX_EVENTS_IN_CHUNK} of events") + logger.info(f"Found {len(files)} files") if len(files) == 0: return TRunMetrics(True, False, 0) # group files by schema diff --git a/tests/normalize/cases/ethereum_blocks_9c1d9b504ea240a482b007788d5cd61c_2.extracted.json b/tests/normalize/cases/ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2.json similarity index 100% rename from tests/normalize/cases/ethereum_blocks_9c1d9b504ea240a482b007788d5cd61c_2.extracted.json rename to tests/normalize/cases/ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2.json diff --git a/tests/normalize/cases/event_bot_load_metadata_1.extracted.json b/tests/normalize/cases/event.event.bot_load_metadata_2987398237498798.json similarity index 100% rename from tests/normalize/cases/event_bot_load_metadata_1.extracted.json rename to tests/normalize/cases/event.event.bot_load_metadata_2987398237498798.json diff --git a/tests/normalize/cases/event_many_load_2.extracted.json b/tests/normalize/cases/event.event.many_load_2.json similarity index 100% rename from tests/normalize/cases/event_many_load_2.extracted.json rename to tests/normalize/cases/event.event.many_load_2.json diff --git a/tests/normalize/cases/event_slot_session_metadata_1.extracted.json b/tests/normalize/cases/event.event.slot_session_metadata_1.json similarity index 100% rename from tests/normalize/cases/event_slot_session_metadata_1.extracted.json rename to tests/normalize/cases/event.event.slot_session_metadata_1.json diff --git a/tests/normalize/cases/event_user_load_1.extracted.json b/tests/normalize/cases/event.event.user_load_1.json similarity index 100% rename from tests/normalize/cases/event_user_load_1.extracted.json rename to tests/normalize/cases/event.event.user_load_1.json diff --git a/tests/normalize/cases/event_user_load_v228_1.extracted.json b/tests/normalize/cases/event.event.user_load_v228_1.json similarity index 100% rename from tests/normalize/cases/event_user_load_v228_1.extracted.json rename to tests/normalize/cases/event.event.user_load_v228_1.json diff --git a/tests/normalize/mock_rasa_json_normalizer.py b/tests/normalize/mock_rasa_json_normalizer.py index f6fbde5d59..3975c484b9 100644 --- a/tests/normalize/mock_rasa_json_normalizer.py +++ b/tests/normalize/mock_rasa_json_normalizer.py @@ -1,17 +1,16 @@ from dlt.common.normalizers.json import TNormalizedRowIterator from dlt.common.schema import Schema from dlt.common.normalizers.json.relational import normalize_data_item as relational_normalize, extend_schema -from dlt.common.sources import with_table_name from dlt.common.typing import TDataItem -def normalize_data_item(schema: Schema, source_event: TDataItem, load_id: str) -> TNormalizedRowIterator: +def normalize_data_item(schema: Schema, source_event: TDataItem, load_id: str, table_name: str) -> TNormalizedRowIterator: if schema.name == "event": # this emulates rasa parser on standard parser - event = {"sender_id": source_event["sender_id"], "timestamp": source_event["timestamp"]} - yield from relational_normalize(schema, event, load_id) + event = {"sender_id": source_event["sender_id"], "timestamp": source_event["timestamp"], "type": source_event["event"]} + yield from relational_normalize(schema, event, load_id, table_name) # add 
table name which is "event" field in RASA OSS - with_table_name(source_event, "event_" + source_event["event"]) - - # will generate tables properly - yield from relational_normalize(schema, source_event, load_id) + yield from relational_normalize(schema, source_event, load_id, table_name + "_" + source_event["event"]) + else: + # will generate tables properly + yield from relational_normalize(schema, source_event, load_id, table_name) diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py index 5edb12244e..c68451edb8 100644 --- a/tests/normalize/test_normalize.py +++ b/tests/normalize/test_normalize.py @@ -1,19 +1,17 @@ from typing import Dict, List, Sequence -import os import pytest -import shutil from fnmatch import fnmatch from prometheus_client import CollectorRegistry +from multiprocessing import get_start_method, Pool from multiprocessing.dummy import Pool as ThreadPool from dlt.common import json from dlt.common.utils import uniq_id from dlt.common.typing import StrAny -from dlt.common.file_storage import FileStorage from dlt.common.schema import TDataType from dlt.common.storages import NormalizeStorage, LoadStorage -from dlt.extract.extractor_storage import ExtractorStorageBase +from experiments.pipeline.extract import ExtractorStorage from dlt.normalize import Normalize, configuration as normalize_configuration, __version__ from tests.cases import JSON_TYPED_DICT, JSON_TYPED_DICT_TYPES @@ -42,7 +40,7 @@ def init_normalize(default_schemas_path: str = None) -> Normalize: initial = {"IMPORT_SCHEMA_PATH": default_schemas_path, "EXTERNAL_SCHEMA_FORMAT": "json"} n = Normalize(normalize_configuration(initial), CollectorRegistry()) # set jsonl as default writer - n.load_storage.preferred_file_format = n.CONFIG.LOADER_FILE_FORMAT = "jsonl" + n.load_storage.loader_file_format = n.CONFIG.LOADER_FILE_FORMAT = "jsonl" return n @@ -56,13 +54,8 @@ def test_intialize(rasa_normalize: Normalize) -> None: pass -# def test_empty_schema_name(raw_normalize: Normalize) -> None: -# schema = raw_normalize.load_or_create_schema("") -# assert schema.name == "" - - def test_normalize_single_user_event_jsonl(raw_normalize: Normalize) -> None: - expected_tables, load_files = normalize_event_user(raw_normalize, "event_user_load_1", EXPECTED_USER_TABLES) + expected_tables, load_files = normalize_event_user(raw_normalize, "event.event.user_load_1", EXPECTED_USER_TABLES) # load, parse and verify jsonl for expected_table in expected_tables: expect_lines_file(raw_normalize.load_storage, load_files[expected_table]) @@ -82,8 +75,8 @@ def test_normalize_single_user_event_jsonl(raw_normalize: Normalize) -> None: def test_normalize_single_user_event_insert(raw_normalize: Normalize) -> None: - raw_normalize.load_storage.preferred_file_format = raw_normalize.CONFIG.LOADER_FILE_FORMAT = "insert_values" - expected_tables, load_files = normalize_event_user(raw_normalize, "event_user_load_1", EXPECTED_USER_TABLES) + raw_normalize.load_storage.loader_file_format = raw_normalize.CONFIG.LOADER_FILE_FORMAT = "insert_values" + expected_tables, load_files = normalize_event_user(raw_normalize, "event.event.user_load_1", EXPECTED_USER_TABLES) # verify values line for expected_table in expected_tables: expect_lines_file(raw_normalize.load_storage, load_files[expected_table]) @@ -99,7 +92,7 @@ def test_normalize_single_user_event_insert(raw_normalize: Normalize) -> None: def test_normalize_filter_user_event(rasa_normalize: Normalize) -> None: - load_id = normalize_cases(rasa_normalize, 
["event_user_load_v228_1"]) + load_id = normalize_cases(rasa_normalize, ["event.event.user_load_v228_1"]) load_files = expect_load_package( rasa_normalize.load_storage, load_id, @@ -115,7 +108,7 @@ def test_normalize_filter_user_event(rasa_normalize: Normalize) -> None: def test_normalize_filter_bot_event(rasa_normalize: Normalize) -> None: - load_id = normalize_cases(rasa_normalize, ["event_bot_load_metadata_1"]) + load_id = normalize_cases(rasa_normalize, ["event.event.bot_load_metadata_2987398237498798"]) load_files = expect_load_package(rasa_normalize.load_storage, load_id, ["event", "event_bot"]) event_text, lines = expect_lines_file(rasa_normalize.load_storage, load_files["event_bot"], 0) assert lines == 1 @@ -125,7 +118,7 @@ def test_normalize_filter_bot_event(rasa_normalize: Normalize) -> None: def test_preserve_slot_complex_value_json_l(rasa_normalize: Normalize) -> None: - load_id = normalize_cases(rasa_normalize, ["event_slot_session_metadata_1"]) + load_id = normalize_cases(rasa_normalize, ["event.event.slot_session_metadata_1"]) load_files = expect_load_package(rasa_normalize.load_storage, load_id, ["event", "event_slot"]) event_text, lines = expect_lines_file(rasa_normalize.load_storage, load_files["event_slot"], 0) assert lines == 1 @@ -138,8 +131,8 @@ def test_preserve_slot_complex_value_json_l(rasa_normalize: Normalize) -> None: def test_preserve_slot_complex_value_insert(rasa_normalize: Normalize) -> None: - rasa_normalize.load_storage.preferred_file_format = rasa_normalize.CONFIG.LOADER_FILE_FORMAT = "insert_values" - load_id = normalize_cases(rasa_normalize, ["event_slot_session_metadata_1"]) + rasa_normalize.load_storage.loader_file_format = rasa_normalize.CONFIG.LOADER_FILE_FORMAT = "insert_values" + load_id = normalize_cases(rasa_normalize, ["event.event.slot_session_metadata_1"]) load_files = expect_load_package(rasa_normalize.load_storage, load_id, ["event", "event_slot"]) event_text, lines = expect_lines_file(rasa_normalize.load_storage, load_files["event_slot"], 2) assert lines == 3 @@ -151,33 +144,50 @@ def test_preserve_slot_complex_value_insert(rasa_normalize: Normalize) -> None: def test_normalize_raw_no_type_hints(raw_normalize: Normalize) -> None: - normalize_event_user(raw_normalize, "event_user_load_1", EXPECTED_USER_TABLES) + normalize_event_user(raw_normalize, "event.event.user_load_1", EXPECTED_USER_TABLES) assert_timestamp_data_type(raw_normalize.load_storage, "double") def test_normalize_raw_type_hints(rasa_normalize: Normalize) -> None: - normalize_cases(rasa_normalize, ["event_user_load_1"]) + normalize_cases(rasa_normalize, ["event.event.user_load_1"]) assert_timestamp_data_type(rasa_normalize.load_storage, "timestamp") def test_normalize_many_events_insert(rasa_normalize: Normalize) -> None: - rasa_normalize.load_storage.preferred_file_format = rasa_normalize.CONFIG.LOADER_FILE_FORMAT = "insert_values" - load_id = normalize_cases(rasa_normalize, ["event_many_load_2", "event_user_load_1"]) + rasa_normalize.load_storage.loader_file_format = rasa_normalize.CONFIG.LOADER_FILE_FORMAT = "insert_values" + load_id = normalize_cases(rasa_normalize, ["event.event.many_load_2", "event.event.user_load_1"]) expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + ["event_bot", "event_action"] load_files = expect_load_package(rasa_normalize.load_storage, load_id, expected_tables) # return first values line from event_user file event_text, lines = expect_lines_file(rasa_normalize.load_storage, load_files["event"], 4) + # 2 lines header + 3 lines data assert 
lines == 5 assert f"'{load_id}'" in event_text +def test_normalize_many_events(rasa_normalize: Normalize) -> None: + load_id = normalize_cases(rasa_normalize, ["event.event.many_load_2", "event.event.user_load_1"]) + expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + ["event_bot", "event_action"] + load_files = expect_load_package(rasa_normalize.load_storage, load_id, expected_tables) + # return first values line from event_user file + event_text, lines = expect_lines_file(rasa_normalize.load_storage, load_files["event"], 2) + # 3 lines data + assert lines == 3 + assert f"{load_id}" in event_text + + def test_normalize_many_schemas(rasa_normalize: Normalize) -> None: - rasa_normalize.load_storage.preferred_file_format = rasa_normalize.CONFIG.LOADER_FILE_FORMAT = "insert_values" - copy_cases( + rasa_normalize.load_storage.loader_file_format = rasa_normalize.CONFIG.LOADER_FILE_FORMAT = "insert_values" + extract_cases( rasa_normalize.normalize_storage, - ["event_many_load_2", "event_user_load_1", "ethereum_blocks_9c1d9b504ea240a482b007788d5cd61c_2"] + ["event.event.many_load_2", "event.event.user_load_1", "ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2"] ) - rasa_normalize.run(ThreadPool(processes=4)) + if get_start_method() != "fork": + # windows, mac os do not support fork + rasa_normalize.run(ThreadPool(processes=4)) + else: + # linux does so use real process pool in tests + rasa_normalize.run(Pool(processes=4)) # must have two loading groups with model and event schemas loads = rasa_normalize.load_storage.list_packages() assert len(loads) == 2 @@ -196,8 +206,8 @@ def test_normalize_many_schemas(rasa_normalize: Normalize) -> None: def test_normalize_typed_json(raw_normalize: Normalize) -> None: - raw_normalize.load_storage.preferred_file_format = raw_normalize.CONFIG.LOADER_FILE_FORMAT = "jsonl" - extract_items(raw_normalize.normalize_storage, [JSON_TYPED_DICT], "special") + raw_normalize.load_storage.loader_file_format = raw_normalize.CONFIG.LOADER_FILE_FORMAT = "jsonl" + extract_items(raw_normalize.normalize_storage, [JSON_TYPED_DICT], "special", "special") raw_normalize.run(ThreadPool(processes=1)) loads = raw_normalize.load_storage.list_packages() assert len(loads) == 1 @@ -222,17 +232,13 @@ def test_normalize_typed_json(raw_normalize: Normalize) -> None: "event__parse_data__response_selector__default__response__responses"] -def extract_items(normalize_storage: NormalizeStorage, items: Sequence[StrAny], schema_name: str) -> None: - extractor = ExtractorStorageBase("1.0.0", True, FileStorage(os.path.join(TEST_STORAGE_ROOT, "extractor"), makedirs=True), normalize_storage) - load_id = uniq_id() - extractor.save_json(f"{load_id}.json", items) - extractor.commit_events( - schema_name, - extractor.storage.make_full_path(f"{load_id}.json"), - "items", - len(items), - load_id - ) +def extract_items(normalize_storage: NormalizeStorage, items: Sequence[StrAny], schema_name: str, table_name: str) -> None: + extractor = ExtractorStorage(normalize_storage.CONFIG) + extract_id = extractor.create_extract_id() + extractor.write_data_item(extract_id, schema_name, table_name, items, None) + extractor.close_writers(extract_id) + extractor.commit_extract_files(extract_id) + def normalize_event_user(normalize: Normalize, case: str, expected_user_tables: List[str] = None) -> None: expected_user_tables = expected_user_tables or EXPECTED_USER_TABLES_RASA_NORMALIZER @@ -241,21 +247,23 @@ def normalize_event_user(normalize: Normalize, case: str, expected_user_tables: def normalize_cases(normalize: 
Normalize, cases: Sequence[str]) -> str: - copy_cases(normalize.normalize_storage, cases) + extract_cases(normalize.normalize_storage, cases) load_id = uniq_id() normalize.load_storage.create_temp_load_package(load_id) # pool not required for map_single - dest_cases = [f"{NormalizeStorage.EXTRACTED_FOLDER}/{c}.extracted.json" for c in cases] + dest_cases = normalize.normalize_storage.storage.list_folder_files(NormalizeStorage.EXTRACTED_FOLDER) # [f"{NormalizeStorage.EXTRACTED_FOLDER}/{c}.extracted.json" for c in cases] # create schema if it does not exist - normalize.load_or_create_schema("event") + Normalize.load_or_create_schema(normalize.schema_storage, "event") normalize.spool_files("event", load_id, normalize.map_single, dest_cases) return load_id -def copy_cases(normalize_storage: NormalizeStorage, cases: Sequence[str]) -> None: +def extract_cases(normalize_storage: NormalizeStorage, cases: Sequence[str]) -> None: for case in cases: - event_user_path = json_case_path(f"{case}.extracted") - shutil.copy(event_user_path, normalize_storage.storage.make_full_path(NormalizeStorage.EXTRACTED_FOLDER)) + schema_name, table_name, _ = NormalizeStorage.parse_normalize_file_name(case + ".jsonl") + with open(json_case_path(case), "r") as f: + items = json.load(f) + extract_items(normalize_storage, items, schema_name, table_name) def expect_load_package(load_storage: LoadStorage, load_id: str, expected_tables: Sequence[str]) -> Dict[str, str]: From 52eebb9975ea8d2cd04a040ebf756d7e3a05d966 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 24 Sep 2022 17:55:02 +0200 Subject: [PATCH 17/66] adds decorator for Pool workers to pass configuration types across process boundary --- dlt/common/runners/pool_runner.py | 4 +-- dlt/common/runners/runnable.py | 47 +++++++++++++++++++++++++-- dlt/common/utils.py | 2 +- tests/common/runners/test_runnable.py | 45 ++++++++++++++++++++++++- 4 files changed, 91 insertions(+), 7 deletions(-) diff --git a/dlt/common/runners/pool_runner.py b/dlt/common/runners/pool_runner.py index c0e159a1af..d5409f974f 100644 --- a/dlt/common/runners/pool_runner.py +++ b/dlt/common/runners/pool_runner.py @@ -1,14 +1,12 @@ -import argparse import multiprocessing from prometheus_client import Counter, Gauge, Summary, CollectorRegistry, REGISTRY -from typing import Callable, Dict, NamedTuple, Optional, Type, TypeVar, Union, cast +from typing import Callable, Dict, Type, Union, cast from multiprocessing.pool import ThreadPool, Pool from dlt.common import logger, signals from dlt.common.runners.runnable import Runnable, TPool from dlt.common.time import sleep from dlt.common.telemetry import TRunHealth, TRunMetrics, get_logging_extras, get_metrics_from_prometheus -from dlt.common.utils import str2bool from dlt.common.exceptions import SignalReceivedException, TimeRangeExhaustedException, UnsupportedProcessStartMethodException from dlt.common.configuration import PoolRunnerConfiguration diff --git a/dlt/common/runners/runnable.py b/dlt/common/runners/runnable.py index ad73cce26d..053d275f3c 100644 --- a/dlt/common/runners/runnable.py +++ b/dlt/common/runners/runnable.py @@ -1,11 +1,15 @@ +import inspect from abc import ABC, abstractmethod from functools import wraps -from typing import Any, Dict, Type, TypeVar, TYPE_CHECKING, Union, Generic +from typing import Any, Dict, Mapping, Type, TypeVar, TYPE_CHECKING, Union, Generic, get_args from multiprocessing.pool import Pool from weakref import WeakValueDictionary +from dlt.common.configuration.run_configuration import BaseConfiguration 
+from dlt.common.typing import StrAny, TFun +from dlt.common.utils import uniq_id from dlt.common.telemetry import TRunMetrics -from dlt.common.typing import TFun +from dlt.common.configuration.utils import TConfiguration TPool = TypeVar("TPool", bound=Pool) @@ -56,3 +60,42 @@ def _wrap(rid: Union[int, Runnable[TPool]], *args: Any, **kwargs: Any) -> Any: return f(rid, *args, **kwargs) return _wrap # type: ignore + + +def configuredworker(f: TFun) -> TFun: + """Decorator for a process/thread pool worker function facilitates passing bound configuration type across the process boundary. It requires the first method + of the worker function to be annotated with type derived from Type[BaseConfiguration] and the worker function to be called (typically by the Pool class) with a + configuration values serialized to dict (via `as_dict` method). The decorator will synthesize a new derived type and apply the serialized value, mimicking the + original type to be transferred across the process boundary. + + Args: + f (TFun): worker function to be decorated + + Raises: + ValueError: raised when worker function signature does not contain required parameters or/and annotations + + + Returns: + TFun: wrapped worker function + """ + @wraps(f) + def _wrap(config: Union[StrAny, Type[BaseConfiguration]], *args: Any, **kwargs: Any) -> Any: + if isinstance(config, Mapping): + # worker process may run in separate process started with spawn and should not share any state with the parent process ie. global variables like config + # first function parameter should be of Type[BaseConfiguration] + sig = inspect.signature(f) + try: + first_param: inspect.Parameter = next(iter(sig.parameters.values())) + T = get_args(first_param.annotation)[0] + if not issubclass(T, BaseConfiguration): + raise ValueError(T) + except Exception: + raise ValueError(f"First parameter of wrapped worker method {f.__name__} must by annotated as Type[BaseConfiguration]") + CONFIG = type(f.__name__ + uniq_id(), (T, ), {}) + CONFIG.apply_dict(config) # type: ignore + config = CONFIG + + return f(config, *args, **kwargs) + + return _wrap # type: ignore + diff --git a/dlt/common/utils.py b/dlt/common/utils.py index 835c57d58f..d8bc2d3491 100644 --- a/dlt/common/utils.py +++ b/dlt/common/utils.py @@ -93,7 +93,7 @@ def flatten_dicts_of_dicts(dicts: Mapping[str, Any]) -> Sequence[Any]: def tuplify_list_of_dicts(dicts: Sequence[DictStrAny]) -> Sequence[DictStrAny]: """ - Transform dicts with single key into {"key": orig_key, "value": orig_value} + Transform list of dictionaries with single key into single dictionary of {"key": orig_key, "value": orig_value} """ for d in dicts: if len(d) > 1: diff --git a/tests/common/runners/test_runnable.py b/tests/common/runners/test_runnable.py index dd883e0254..eef2771e23 100644 --- a/tests/common/runners/test_runnable.py +++ b/tests/common/runners/test_runnable.py @@ -1,7 +1,13 @@ import gc +from typing import Type +import pytest +from multiprocessing import get_start_method from multiprocessing.pool import Pool from multiprocessing.dummy import Pool as ThreadPool -import pytest + +from dlt.common.runners.runnable import configuredworker +from dlt.common.utils import uniq_id +from dlt.normalize.configuration import NormalizeConfiguration from tests.common.runners.utils import _TestRunnable from tests.utils import skipifspawn @@ -62,3 +68,40 @@ def test_weak_pool_ref() -> None: # weak reference will be removed from container with pytest.raises(KeyError): r = wref[rid] + + +def test_configuredworker() -> None: + + # 
call worker method with CONFIG values that should be restored into CONFIG type + config = NormalizeConfiguration.as_dict() + config["import_schema_path"] = "test_schema_path" + _worker_1(config, "PX1", par2="PX2") + + # may also be called directly + NormT = type("TEST_" + uniq_id(), (NormalizeConfiguration, ), {}) + NormT.IMPORT_SCHEMA_PATH = "test_schema_path" + _worker_1(NormT, "PX1", par2="PX2") + + # must also work across process boundary + with Pool(1) as p: + p.starmap(_worker_1, [(config, "PX1", "PX2")]) + + # wrong signature error + with pytest.raises(ValueError): + _wrong_worker_sig(config) + + +@configuredworker +def _wrong_worker_sig(CONFIG: NormalizeConfiguration) -> None: + pass + +@configuredworker +def _worker_1(CONFIG: Type[NormalizeConfiguration], par1: str, par2: str = "DEFAULT") -> None: + assert issubclass(CONFIG, NormalizeConfiguration) + # it is a subclass but not the same type + assert not CONFIG is NormalizeConfiguration + # check if config values are restored + assert CONFIG.IMPORT_SCHEMA_PATH == "test_schema_path" + # check if other parameters are correctly + assert par1 == "PX1" + assert par2 == "PX2" \ No newline at end of file From 693f89d1589636eac59aa64816c9984d781fbf2c Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 24 Sep 2022 17:57:47 +0200 Subject: [PATCH 18/66] simplifies ExtractorStorage by moving data items storage out --- experiments/pipeline/extract.py | 39 ++++++++------------------------ experiments/pipeline/pipeline.py | 1 - experiments/pipeline/sources.py | 5 +++- 3 files changed, 14 insertions(+), 31 deletions(-) diff --git a/experiments/pipeline/extract.py b/experiments/pipeline/extract.py index 8ed57c1486..e3209b64dc 100644 --- a/experiments/pipeline/extract.py +++ b/experiments/pipeline/extract.py @@ -1,13 +1,10 @@ import os -from typing import Dict, List, Sequence, Type +from typing import List, Type -from dlt.common import logger -from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.utils import uniq_id from dlt.common.sources import TDirectDataItem, TDataItem from dlt.common.schema import utils, TSchemaUpdate -from dlt.common.data_writers import BufferedDataWriter -from dlt.common.storages import NormalizeStorage +from dlt.common.storages import NormalizeStorage, DataItemStorage from dlt.common.configuration import NormalizeVolumeConfiguration @@ -15,14 +12,13 @@ from experiments.pipeline.sources import DltResource, DltSource -class ExtractorStorage(NormalizeStorage): +class ExtractorStorage(DataItemStorage, NormalizeStorage): EXTRACT_FOLDER = "extract" - EXTRACT_FILE_NAME_TEMPLATE = "" def __init__(self, C: Type[NormalizeVolumeConfiguration]) -> None: - super().__init__(False, C) + # data item storage with jsonl with pua encoding + super().__init__("puae-jsonl", False, C) self.initialize_storage() - self.buffered_writers: Dict[str, BufferedDataWriter] = {} def initialize_storage(self) -> None: self.storage.create_folder(ExtractorStorage.EXTRACT_FOLDER, exists_ok=True) @@ -45,25 +41,9 @@ def commit_extract_files(self, extract_id: str, with_delete: bool = True) -> Non if with_delete: self.storage.delete_folder(extract_path, recursively=True) - def write_data_item(self, extract_id: str, schema_name: str, table_name: str, item: TDirectDataItem, columns: TTableSchemaColumns) -> None: - # unique writer id - writer_id = f"{extract_id}.{schema_name}.{table_name}" - writer = self.buffered_writers.get(writer_id, None) - if not writer: - # assign a jsonl writer with pua encoding for each table, use %s for file id to 
create required template - template = NormalizeStorage.build_extracted_file_stem(schema_name, table_name, "%s") - path = self.storage.make_full_path(os.path.join(self._get_extract_path(extract_id), template)) - writer = BufferedDataWriter("puae-jsonl", path) - self.buffered_writers[writer_id] = writer - # write item(s) - writer.write_data_item(item, columns) - - def close_writers(self, extract_id: str) -> None: - # flush and close all files - for name, writer in self.buffered_writers.items(): - if name.startswith(extract_id): - logger.debug(f"Closing writer for {name} with file {writer._file} and actual name {writer._file_name}") - writer.close_writer() + def _get_data_item_path_template(self, load_id: str, schema_name: str, table_name: str) -> str: + template = NormalizeStorage.build_extracted_file_stem(schema_name, table_name, "%s") + return self.storage.make_full_path(os.path.join(self._get_extract_path(load_id), template)) def _get_extract_path(self, extract_id: str) -> str: return os.path.join(ExtractorStorage.EXTRACT_FOLDER, extract_id) @@ -78,6 +58,8 @@ def _write_item(table_name: str, item: TDirectDataItem) -> None: # normalize table name before writing so the name match the name in schema # note: normalize function should be cached so there's almost no penalty on frequent calling # note: column schema is not required for jsonl writer used here + # TODO: consider dropping DLT_METADATA_FIELD in all items before writing, this however takes CPU time + # event.pop(DLT_METADATA_FIELD, None) # type: ignore storage.write_data_item(extract_id, schema.name, schema.normalize_table_name(table_name), item, None) def _write_dynamic_table(resource: DltResource, item: TDataItem) -> None: @@ -100,7 +82,6 @@ def _write_dynamic_table(resource: DltResource, item: TDataItem) -> None: # yield from all selected pipes for pipe_item in PipeIterator.from_pipes(source.pipes): # get partial table from table template - print(pipe_item) resource = source.resource_by_pipe(pipe_item.pipe) if resource._table_name_hint_fun: if isinstance(pipe_item.item, List): diff --git a/experiments/pipeline/pipeline.py b/experiments/pipeline/pipeline.py index 54c2113078..b14cb43ba4 100644 --- a/experiments/pipeline/pipeline.py +++ b/experiments/pipeline/pipeline.py @@ -268,7 +268,6 @@ def normalize(self, dry_run: bool = False, workers: int = 1, max_events_in_chunk # set parameters to be passed to config normalize = self._configure_normalize({ "WORKERS": workers, - "MAX_EVENTS_IN_CHUNK": max_events_in_chunk, "POOL_TYPE": "thread" if workers == 1 else "process" }) try: diff --git a/experiments/pipeline/sources.py b/experiments/pipeline/sources.py index 95a55c1e0c..4b85646909 100644 --- a/experiments/pipeline/sources.py +++ b/experiments/pipeline/sources.py @@ -29,6 +29,7 @@ def __init__(self, name: str, table_schema_template: TTableSchemaTemplate = None self._table_name_hint_fun: TFunDataItemDynHint = None self._table_has_other_dynamic_hints: bool = False self._table_schema_template: TTableSchemaTemplate = None + self._table_schema: TPartialTableSchema = None if table_schema_template: self._set_template(table_schema_template) @@ -36,7 +37,9 @@ def table_schema(self, item: TDataItem = None) -> TPartialTableSchema: if not self._table_schema_template: # if table template is not present, generate partial table from name - return new_table(self.name) + if not self._table_schema: + self._table_schema = new_table(self.name) + return self._table_schema def _resolve_hint(hint: Union[Any, TFunDataItemDynHint]) -> Any: if callable(hint): 
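
The two patches above introduce the pieces the later commits build on: a `configuredworker` decorator that rebuilds a configuration type inside a pool worker, and an `ExtractorStorage` that delegates buffered item writing to `DataItemStorage`. A minimal sketch of the extract-side flow, mirroring the `extract_items` test helper shown earlier, looks roughly like this — the schema name, table name and sample item are made up for illustration, and the configuration argument changes again in the refactor that follows:

    from dlt.common.configuration import NormalizeVolumeConfiguration
    from experiments.pipeline.extract import ExtractorStorage

    # one buffered "puae-jsonl" writer is kept per extract id / schema / table
    storage = ExtractorStorage(NormalizeVolumeConfiguration)
    extract_id = storage.create_extract_id()
    # write a batch of items for table "event_user" in schema "event";
    # the jsonl writer does not need a column schema, so None is passed
    storage.write_data_item(extract_id, "event", "event_user", [{"sender_id": "1"}], None)
    # flush and close the buffered writers, then hand the files over for normalization
    storage.close_writers(extract_id)
    storage.commit_extract_files(extract_id)
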
From 1ce4735861977a72721c20512f7a1f5961145008 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 28 Sep 2022 22:14:05 +0200 Subject: [PATCH 19/66] refactors configurations: uses dataclass based instances, no production configs, snake case field names, recursive binding to values --- Makefile | 4 +- dlt/cli/dlt.py | 1 - dlt/common/configuration/__init__.py | 9 +- .../configuration/base_configuration.py | 108 ++++ dlt/common/configuration/exceptions.py | 16 +- .../configuration/gcp_client_credentials.py | 53 +- .../load_volume_configuration.py | 12 +- .../normalize_volume_configuration.py | 11 +- .../pool_runner_configuration.py | 22 +- .../configuration/postgres_credentials.py | 38 +- dlt/common/configuration/providers/environ.py | 7 +- dlt/common/configuration/run_configuration.py | 87 +--- .../schema_volume_configuration.py | 17 +- dlt/common/configuration/utils.py | 184 +++---- dlt/common/logger.py | 36 +- dlt/common/runners/init.py | 2 +- dlt/common/runners/pool_runner.py | 34 +- dlt/common/runners/runnable.py | 73 ++- dlt/common/storages/live_schema_storage.py | 8 +- dlt/common/storages/load_storage.py | 8 +- dlt/common/storages/normalize_storage.py | 6 +- dlt/common/storages/schema_storage.py | 44 +- dlt/common/typing.py | 3 +- dlt/common/utils.py | 15 +- dlt/dbt_runner/configuration.py | 74 ++- dlt/dbt_runner/runner.py | 50 +- dlt/helpers/streamlit.py | 10 +- dlt/load/bigquery/client.py | 38 +- dlt/load/bigquery/configuration.py | 19 +- dlt/load/client_base.py | 2 +- dlt/load/configuration.py | 32 +- dlt/load/dummy/client.py | 16 +- dlt/load/dummy/configuration.py | 21 +- dlt/load/load.py | 10 +- dlt/load/redshift/client.py | 14 +- dlt/load/redshift/configuration.py | 13 +- dlt/normalize/configuration.py | 20 +- dlt/normalize/normalize.py | 14 +- dlt/pipeline/pipeline.py | 19 +- experiments/pipeline/configuration.py | 113 +++- experiments/pipeline/extract.py | 4 +- experiments/pipeline/pipeline.py | 19 +- poetry.lock | 39 +- pyproject.toml | 1 - tests/common/runners/test_runnable.py | 28 +- tests/common/runners/test_runners.py | 35 +- tests/common/schema/test_schema.py | 5 +- tests/common/storages/test_loader_storage.py | 2 +- tests/common/storages/test_schema_storage.py | 22 +- tests/common/test_configuration.py | 481 ++++++++++-------- tests/common/test_logging.py | 41 +- tests/conftest.py | 17 + tests/dbt_runner/test_runner_bigquery.py | 4 +- tests/dbt_runner/test_runner_redshift.py | 24 +- tests/dbt_runner/utils.py | 4 +- tests/load/bigquery/test_bigquery_client.py | 2 +- .../bigquery/test_bigquery_table_builder.py | 10 +- .../redshift/test_redshift_table_builder.py | 8 +- tests/load/test_client.py | 4 +- tests/load/test_dummy_client.py | 38 +- tests/load/utils.py | 6 +- tests/normalize/test_normalize.py | 18 +- tests/utils.py | 16 +- 63 files changed, 1128 insertions(+), 963 deletions(-) create mode 100644 dlt/common/configuration/base_configuration.py create mode 100644 tests/conftest.py diff --git a/Makefile b/Makefile index fb1ca191d1..d5108293b8 100644 --- a/Makefile +++ b/Makefile @@ -38,7 +38,7 @@ dev: has-poetry lint: ./check-package.sh - poetry run mypy --config-file mypy.ini dlt examples experiments/pipeline + poetry run mypy --config-file mypy.ini dlt poetry run flake8 --max-line-length=200 examples dlt poetry run flake8 --max-line-length=200 tests # dlt/pipeline dlt/common/schema dlt/common/normalizers @@ -50,7 +50,7 @@ lint-security: reset-test-storage: -rm -r _storage mkdir _storage - python3 test/tools/create_storages.py + python3 
tests/tools/create_storages.py recreate-compiled-deps: poetry export -f requirements.txt --output _gen_requirements.txt --without-hashes --extras gcp --extras redshift diff --git a/dlt/cli/dlt.py b/dlt/cli/dlt.py index 33759f802e..43571860c2 100644 --- a/dlt/cli/dlt.py +++ b/dlt/cli/dlt.py @@ -7,7 +7,6 @@ from dlt.cli import TRunnerArgs from dlt.common.schema import Schema from dlt.common.typing import DictStrAny -from dlt.common.utils import str2bool from dlt.pipeline import Pipeline, PostgresPipelineCredentials diff --git a/dlt/common/configuration/__init__.py b/dlt/common/configuration/__init__.py index 0be19399ce..aefd544548 100644 --- a/dlt/common/configuration/__init__.py +++ b/dlt/common/configuration/__init__.py @@ -1,7 +1,8 @@ -from .run_configuration import RunConfiguration, BaseConfiguration, CredentialsConfiguration # noqa: F401 -from .normalize_volume_configuration import NormalizeVolumeConfiguration, ProductionNormalizeVolumeConfiguration # noqa: F401 -from .load_volume_configuration import LoadVolumeConfiguration, ProductionLoadVolumeConfiguration # noqa: F401 -from .schema_volume_configuration import SchemaVolumeConfiguration, ProductionSchemaVolumeConfiguration # noqa: F401 +from .run_configuration import RunConfiguration # noqa: F401 +from .base_configuration import BaseConfiguration, CredentialsConfiguration, configspec # noqa: F401 +from .normalize_volume_configuration import NormalizeVolumeConfiguration # noqa: F401 +from .load_volume_configuration import LoadVolumeConfiguration # noqa: F401 +from .schema_volume_configuration import SchemaVolumeConfiguration # noqa: F401 from .pool_runner_configuration import PoolRunnerConfiguration, TPoolType # noqa: F401 from .gcp_client_credentials import GcpClientCredentials # noqa: F401 from .postgres_credentials import PostgresCredentials # noqa: F401 diff --git a/dlt/common/configuration/base_configuration.py b/dlt/common/configuration/base_configuration.py new file mode 100644 index 0000000000..7b92a4aa57 --- /dev/null +++ b/dlt/common/configuration/base_configuration.py @@ -0,0 +1,108 @@ +import dataclasses +from typing import Any, Dict, Iterator, MutableMapping, Type, TYPE_CHECKING + +if TYPE_CHECKING: + TDtcField = dataclasses.Field[Any] +else: + TDtcField = dataclasses.Field + +from dlt.common.typing import TAny +from dlt.common.configuration.exceptions import ConfigFieldTypingMissingException + + +def configspec(cls: Type[TAny] = None, /, *, init: bool = False) -> Type[TAny]: + + def wrap(cls: Type[TAny]) -> Type[TAny]: + # get all annotations without corresponding attributes and set them to None + for ann in cls.__annotations__: + if not hasattr(cls, ann): + setattr(cls, ann, None) + # get all attributes without corresponding annotations + for att_name, att in cls.__dict__.items(): + if not callable(att) and not att_name.startswith(("__", "_abc_impl")) and att_name not in cls.__annotations__: + raise ConfigFieldTypingMissingException(att_name, cls) + return dataclasses.dataclass(cls, init=init, eq=False) # type: ignore + + # called with parenthesis + if cls is None: + return wrap # type: ignore + + return wrap(cls) + + +@configspec +class BaseConfiguration(MutableMapping[str, Any]): + + # will be set to true if not all config entries could be resolved + __is_partial__: bool = dataclasses.field(default = True, init=False, repr=False) + # namespace used by config providers when searching for keys + __namespace__: str = dataclasses.field(default = None, init=False, repr=False) + __dataclass_fields__: Dict[str, TDtcField] + + 
def __init__(self) -> None: + self.__ignore_set_unknown_keys = False + + def from_native_representation(self, native_value: Any) -> None: + """Initialize the configuration fields by parsing the `initial_value` which should be a native representation of the configuration + or credentials, for example database connection string or JSON serialized GCP service credentials file. + + Args: + initial_value (Any): A native representation of the configuration + + Raises: + NotImplementedError: This configuration does not have a native representation + ValueError: The value provided cannot be parsed as native representation + """ + raise NotImplementedError() + + def to_native_representation(self) -> Any: + """Represents the configuration instance in its native form ie. database connection string or JSON serialized GCP service credentials file. + + Raises: + NotImplementedError: This configuration does not have a native representation + + Returns: + Any: A native representation of the configuration + """ + raise NotImplementedError() + + # implement dictionary-compatible interface on top of dataclass + + def __getitem__(self, __key: str) -> Any: + if self._has_attr(__key): + return getattr(self, __key) + else: + raise KeyError(__key) + + def __setitem__(self, __key: str, __value: Any) -> None: + if self._has_attr(__key): + setattr(self, __key, __value) + else: + if not self.__ignore_set_unknown_keys: + raise KeyError(__key) + + def __delitem__(self, __key: str) -> None: + raise NotImplementedError("Configuration fields cannot be deleted") + + def __iter__(self) -> Iterator[str]: + return filter(lambda k: not k.startswith("__"), self.__dataclass_fields__.__iter__()) + + def __len__(self) -> int: + return sum(1 for _ in self.__iter__()) + + def update(self, other: Any = (), /, **kwds: Any) -> None: + try: + self.__ignore_set_unknown_keys = True + super().update(other, **kwds) + finally: + self.__ignore_set_unknown_keys = False + + # helper functions + + def _has_attr(self, __key: str) -> bool: + return __key in self.__dataclass_fields__ and not __key.startswith("__") + + +@configspec +class CredentialsConfiguration(BaseConfiguration): + pass diff --git a/dlt/common/configuration/exceptions.py b/dlt/common/configuration/exceptions.py index eb859d1b30..a1ee0c7bce 100644 --- a/dlt/common/configuration/exceptions.py +++ b/dlt/common/configuration/exceptions.py @@ -1,4 +1,4 @@ -from typing import Iterable, Union +from typing import Any, Iterable, Type, Union from dlt.common.exceptions import DltException @@ -8,6 +8,11 @@ def __init__(self, msg: str) -> None: super().__init__(msg) +class ConfigurationWrongTypeException(ConfigurationException): + def __init__(self, _typ: type) -> None: + super().__init__(f"Invalid configuration instance type {_typ}. 
Configuration instances must derive from BaseConfiguration.") + + class ConfigEntryMissingException(ConfigurationException): """thrown when not all required config elements are present""" @@ -46,3 +51,12 @@ class ConfigFileNotFoundException(ConfigurationException): def __init__(self, path: str) -> None: super().__init__(f"Missing config file in {path}") + + +class ConfigFieldTypingMissingException(ConfigurationException): + """thrown when configuration specification does not have type annotation""" + + def __init__(self, field_name: str, typ_: Type[Any]) -> None: + self.field_name = field_name + self.typ_ = typ_ + super().__init__(f"Field {field_name} on configspec {typ_} does not provide required type annotation") diff --git a/dlt/common/configuration/gcp_client_credentials.py b/dlt/common/configuration/gcp_client_credentials.py index 5388eb6e9f..0d42aad990 100644 --- a/dlt/common/configuration/gcp_client_credentials.py +++ b/dlt/common/configuration/gcp_client_credentials.py @@ -1,33 +1,44 @@ +from typing import Any +from dlt.common import json + from dlt.common.typing import StrAny, TSecretValue -from dlt.common.configuration import CredentialsConfiguration +from dlt.common.configuration.base_configuration import CredentialsConfiguration, configspec +@configspec class GcpClientCredentials(CredentialsConfiguration): - __namespace__: str = "GCP" + __namespace__: str = "gcp" + + project_id: str = None + type: str = "service_account" # noqa: A003 + private_key: TSecretValue = None + location: str = "US" + token_uri: str = "https://oauth2.googleapis.com/token" + client_email: str = None - PROJECT_ID: str = None - CRED_TYPE: str = "service_account" - PRIVATE_KEY: TSecretValue = None - LOCATION: str = "US" - TOKEN_URI: str = "https://oauth2.googleapis.com/token" - CLIENT_EMAIL: str = None + http_timeout: float = 15.0 + retry_deadline: float = 600 - HTTP_TIMEOUT: float = 15.0 - RETRY_DEADLINE: float = 600 + def from_native_representation(self, initial_value: Any) -> None: + if not isinstance(initial_value, str): + raise ValueError(initial_value) + try: + service_dict = json.loads(initial_value) + self.update(service_dict) + except Exception: + raise ValueError(initial_value) - @classmethod - def check_integrity(cls) -> None: - if cls.PRIVATE_KEY and cls.PRIVATE_KEY[-1] != "\n": + def check_integrity(self) -> None: + if self.private_key and self.private_key[-1] != "\n": # must end with new line, otherwise won't be parsed by Crypto - cls.PRIVATE_KEY = TSecretValue(cls.PRIVATE_KEY + "\n") + self.private_key = TSecretValue(self.private_key + "\n") - @classmethod - def as_credentials(cls) -> StrAny: + def to_native_representation(self) -> StrAny: return { - "type": cls.CRED_TYPE, - "project_id": cls.PROJECT_ID, - "private_key": cls.PRIVATE_KEY, - "token_uri": cls.TOKEN_URI, - "client_email": cls.CLIENT_EMAIL + "type": self.type, + "project_id": self.project_id, + "private_key": self.private_key, + "token_uri": self.token_uri, + "client_email": self.client_email } \ No newline at end of file diff --git a/dlt/common/configuration/load_volume_configuration.py b/dlt/common/configuration/load_volume_configuration.py index 41e1746769..9f2fdada3d 100644 --- a/dlt/common/configuration/load_volume_configuration.py +++ b/dlt/common/configuration/load_volume_configuration.py @@ -1,11 +1,7 @@ -import os - -from dlt.common.configuration.run_configuration import BaseConfiguration +from dlt.common.configuration.base_configuration import BaseConfiguration, configspec +@configspec class 
LoadVolumeConfiguration(BaseConfiguration): - LOAD_VOLUME_PATH: str = os.path.join("_storage", "load") # path to volume where files to be loaded to analytical storage are stored - DELETE_COMPLETED_JOBS: bool = False # if set to true the folder with completed jobs will be deleted - -class ProductionLoadVolumeConfiguration(LoadVolumeConfiguration): - LOAD_VOLUME_PATH: str = None + load_volume_path: str = None # path to volume where files to be loaded to analytical storage are stored + delete_completed_jobs: bool = False # if set to true the folder with completed jobs will be deleted diff --git a/dlt/common/configuration/normalize_volume_configuration.py b/dlt/common/configuration/normalize_volume_configuration.py index 3e3b8c34d6..12ff684c54 100644 --- a/dlt/common/configuration/normalize_volume_configuration.py +++ b/dlt/common/configuration/normalize_volume_configuration.py @@ -1,11 +1,6 @@ -import os - -from dlt.common.configuration import BaseConfiguration +from dlt.common.configuration.base_configuration import BaseConfiguration, configspec +@configspec class NormalizeVolumeConfiguration(BaseConfiguration): - NORMALIZE_VOLUME_PATH: str = os.path.join("_storage", "normalize") # path to volume where normalized loader files will be stored - - -class ProductionNormalizeVolumeConfiguration(NormalizeVolumeConfiguration): - NORMALIZE_VOLUME_PATH: str = None + normalize_volume_path: str = None # path to volume where normalized loader files will be stored diff --git a/dlt/common/configuration/pool_runner_configuration.py b/dlt/common/configuration/pool_runner_configuration.py index 9e900c3fca..e5ef5665d8 100644 --- a/dlt/common/configuration/pool_runner_configuration.py +++ b/dlt/common/configuration/pool_runner_configuration.py @@ -1,16 +1,18 @@ from typing import Literal, Optional -from dlt.common.configuration import RunConfiguration + +from dlt.common.configuration.run_configuration import RunConfiguration, configspec TPoolType = Literal["process", "thread", "none"] +@configspec class PoolRunnerConfiguration(RunConfiguration): - POOL_TYPE: TPoolType = None # type of pool to run, must be set in derived configs - WORKERS: Optional[int] = None # how many threads/processes in the pool - RUN_SLEEP: float = 0.5 # how long to sleep between runs with workload, seconds - RUN_SLEEP_IDLE: float = 1.0 # how long to sleep when no more items are pending, seconds - RUN_SLEEP_WHEN_FAILED: float = 1.0 # how long to sleep between the runs when failed - IS_SINGLE_RUN: bool = False # should run only once until all pending data is processed, and exit - WAIT_RUNS: int = 0 # how many runs to wait for first data coming in is IS_SINGLE_RUN is set - EXIT_ON_EXCEPTION: bool = False # should exit on exception - STOP_AFTER_RUNS: int = 10000 # will stop runner with exit code -2 after so many runs, that prevents memory fragmentation + pool_type: TPoolType = None # type of pool to run, must be set in derived configs + workers: Optional[int] = None # how many threads/processes in the pool + run_sleep: float = 0.5 # how long to sleep between runs with workload, seconds + run_sleep_idle: float = 1.0 # how long to sleep when no more items are pending, seconds + run_sleep_when_failed: float = 1.0 # how long to sleep between the runs when failed + is_single_run: bool = False # should run only once until all pending data is processed, and exit + wait_runs: int = 0 # how many runs to wait for first data coming in is IS_SINGLE_RUN is set + exit_on_exception: bool = False # should exit on exception + stop_after_runs: int = 10000 # 
will stop runner with exit code -2 after so many runs, that prevents memory fragmentation diff --git a/dlt/common/configuration/postgres_credentials.py b/dlt/common/configuration/postgres_credentials.py index 4b090b6a65..62b639eac8 100644 --- a/dlt/common/configuration/postgres_credentials.py +++ b/dlt/common/configuration/postgres_credentials.py @@ -1,24 +1,30 @@ +from typing import Any + from dlt.common.typing import StrAny, TSecretValue -from dlt.common.configuration import CredentialsConfiguration +from dlt.common.configuration.base_configuration import CredentialsConfiguration, configspec +@configspec class PostgresCredentials(CredentialsConfiguration): - __namespace__: str = "PG" + __namespace__: str = "pg" + + dbname: str = None + password: TSecretValue = None + user: str = None + host: str = None + port: int = 5439 + connect_timeout: int = 15 - DBNAME: str = None - PASSWORD: TSecretValue = None - USER: str = None - HOST: str = None - PORT: int = 5439 - CONNECT_TIMEOUT: int = 15 + def from_native_repesentation(self, initial_value: Any) -> None: + if not isinstance(initial_value, str): + raise ValueError(initial_value) + # TODO: parse postgres connection string + raise NotImplementedError() - @classmethod - def check_integrity(cls) -> None: - cls.DBNAME = cls.DBNAME.lower() - # cls.DEFAULT_DATASET = cls.DEFAULT_DATASET.lower() - cls.PASSWORD = TSecretValue(cls.PASSWORD.strip()) + def check_integrity(self) -> None: + self.dbname = self.dbname.lower() + self.password = TSecretValue(self.password.strip()) - @classmethod - def as_credentials(cls) -> StrAny: - return cls.as_dict() + def to_native_representation(self) -> StrAny: + raise NotImplementedError() diff --git a/dlt/common/configuration/providers/environ.py b/dlt/common/configuration/providers/environ.py index 03a12ba324..c5364f5973 100644 --- a/dlt/common/configuration/providers/environ.py +++ b/dlt/common/configuration/providers/environ.py @@ -8,11 +8,12 @@ def get_key_name(key: str, namespace: str = None) -> str: + # env key is always upper case if namespace: - return namespace + "__" + key + env_key = namespace + "__" + key else: - return key - + env_key = key + return env_key.upper() def get_key(key: str, hint: Type[Any], namespace: str = None) -> Optional[str]: # apply namespace to the key diff --git a/dlt/common/configuration/run_configuration.py b/dlt/common/configuration/run_configuration.py index 383de508f3..88af7699d1 100644 --- a/dlt/common/configuration/run_configuration.py +++ b/dlt/common/configuration/run_configuration.py @@ -1,73 +1,34 @@ -import randomname from os.path import isfile from typing import Any, Optional, Tuple, IO -from dlt.common.typing import StrAny, DictStrAny -from dlt.common.utils import encoding_for_mode +from dlt.common.utils import encoding_for_mode, entry_point_file_stem +from dlt.common.configuration.base_configuration import BaseConfiguration, configspec from dlt.common.configuration.exceptions import ConfigFileNotFoundException -DEVELOPMENT_CONFIG_FILES_STORAGE_PATH = "_storage/config/%s" -PRODUCTION_CONFIG_FILES_STORAGE_PATH = "/run/config/%s" - - -class BaseConfiguration: - - # will be set to true if not all config entries could be resolved - __is_partial__: bool = True - __namespace__: str = None - - @classmethod - def as_dict(config, lowercase: bool = True) -> DictStrAny: - may_lower = lambda k: k.lower() if lowercase else k - return {may_lower(k):getattr(config, k) for k in dir(config) if not callable(getattr(config, k)) and not k.startswith("__")} - - @classmethod - def 
apply_dict(config, values: StrAny, uppercase: bool = True, apply_non_spec: bool = False) -> None: - if not values: - return - - for k, v in values.items(): - k = k.upper() if uppercase else k - # overwrite only declared values - if not apply_non_spec and hasattr(config, k): - setattr(config, k, v) - - -class CredentialsConfiguration(BaseConfiguration): - pass - +@configspec class RunConfiguration(BaseConfiguration): - PIPELINE_NAME: Optional[str] = None # the name of the component - SENTRY_DSN: Optional[str] = None # keep None to disable Sentry - PROMETHEUS_PORT: Optional[int] = None # keep None to disable Prometheus - LOG_FORMAT: str = '{asctime}|[{levelname:<21}]|{process}|{name}|{filename}|{funcName}:{lineno}|{message}' - LOG_LEVEL: str = "DEBUG" - IS_DEVELOPMENT_CONFIG: bool = True - REQUEST_TIMEOUT: Tuple[int, int] = (15, 300) # default request timeout for all http clients - CONFIG_FILES_STORAGE_PATH: str = DEVELOPMENT_CONFIG_FILES_STORAGE_PATH - - @classmethod - def check_integrity(cls) -> None: - # generate random name if missing - if not cls.PIPELINE_NAME: - cls.PIPELINE_NAME = "dlt_" + randomname.get_name().replace("-", "_") - # if CONFIG_FILES_STORAGE_PATH not overwritten and we are in production mode - if cls.CONFIG_FILES_STORAGE_PATH == DEVELOPMENT_CONFIG_FILES_STORAGE_PATH and not cls.IS_DEVELOPMENT_CONFIG: - # set to mount where config files will be present - cls.CONFIG_FILES_STORAGE_PATH = PRODUCTION_CONFIG_FILES_STORAGE_PATH - - @classmethod - def has_configuration_file(cls, name: str) -> bool: - return isfile(cls.get_configuration_file_path(name)) - - @classmethod - def open_configuration_file(cls, name: str, mode: str) -> IO[Any]: - path = cls.get_configuration_file_path(name) - if not cls.has_configuration_file(name): + pipeline_name: Optional[str] = None # the name of the component + sentry_dsn: Optional[str] = None # keep None to disable Sentry + prometheus_port: Optional[int] = None # keep None to disable Prometheus + log_format: str = '{asctime}|[{levelname:<21}]|{process}|{name}|{filename}|{funcName}:{lineno}|{message}' + log_level: str = "DEBUG" + request_timeout: Tuple[int, int] = (15, 300) # default request timeout for all http clients + config_files_storage_path: str = "/run/config/%s" + + def check_integrity(self) -> None: + # generate pipeline name from the entry point script name + if not self.pipeline_name: + self.pipeline_name = "dlt_" + (entry_point_file_stem() or "pipeline") + + def has_configuration_file(self, name: str) -> bool: + return isfile(self.get_configuration_file_path(name)) + + def open_configuration_file(self, name: str, mode: str) -> IO[Any]: + path = self.get_configuration_file_path(name) + if not self.has_configuration_file(name): raise ConfigFileNotFoundException(path) return open(path, mode, encoding=encoding_for_mode(mode)) - @classmethod - def get_configuration_file_path(cls, name: str) -> str: - return cls.CONFIG_FILES_STORAGE_PATH % name \ No newline at end of file + def get_configuration_file_path(self, name: str) -> str: + return self.config_files_storage_path % name diff --git a/dlt/common/configuration/schema_volume_configuration.py b/dlt/common/configuration/schema_volume_configuration.py index b3018a1782..6e9677ea14 100644 --- a/dlt/common/configuration/schema_volume_configuration.py +++ b/dlt/common/configuration/schema_volume_configuration.py @@ -1,17 +1,14 @@ from typing import Optional, Literal -from dlt.common.configuration import BaseConfiguration +from dlt.common.configuration.base_configuration import BaseConfiguration, 
configspec TSchemaFileFormat = Literal["json", "yaml"] +@configspec class SchemaVolumeConfiguration(BaseConfiguration): - SCHEMA_VOLUME_PATH: str = "_storage/schemas" # path to volume with default schemas - IMPORT_SCHEMA_PATH: Optional[str] = None # import schema from external location - EXPORT_SCHEMA_PATH: Optional[str] = None # export schema to external location - EXTERNAL_SCHEMA_FORMAT: TSchemaFileFormat = "yaml" # format in which to expect external schema - EXTERNAL_SCHEMA_FORMAT_REMOVE_DEFAULTS: bool = True # remove default values when exporting schema - - -class ProductionSchemaVolumeConfiguration(SchemaVolumeConfiguration): - SCHEMA_VOLUME_PATH: str = None + schema_volume_path: str = None # path to volume with default schemas + import_schema_path: Optional[str] = None # import schema from external location + export_schema_path: Optional[str] = None # export schema to external location + external_schema_format: TSchemaFileFormat = "yaml" # format in which to expect external schema + external_schema_format_remove_defaults: bool = True # remove default values when exporting schema diff --git a/dlt/common/configuration/utils.py b/dlt/common/configuration/utils.py index cced372730..0c5f86698f 100644 --- a/dlt/common/configuration/utils.py +++ b/dlt/common/configuration/utils.py @@ -1,134 +1,104 @@ +import dataclasses +import inspect import sys import semver -from typing import Any, Dict, List, Mapping, Type, TypeVar, cast +from typing import Any, Dict, List, Mapping, Type, TypeVar -from dlt.common.typing import StrAny, is_optional_type, is_literal_type +from dlt.common.typing import is_optional_type, is_literal_type from dlt.common.configuration import BaseConfiguration from dlt.common.configuration.providers import environ -from dlt.common.configuration.exceptions import (ConfigEntryMissingException, - ConfigEnvValueCannotBeCoercedException) -from dlt.common.utils import uniq_id +from dlt.common.configuration.exceptions import (ConfigEntryMissingException, ConfigurationWrongTypeException, ConfigEnvValueCannotBeCoercedException) SIMPLE_TYPES: List[Any] = [int, bool, list, dict, tuple, bytes, set, float] # those types and Optionals of those types should not be passed to eval function NON_EVAL_TYPES = [str, None, Any] # allows to coerce (type1 from type2) ALLOWED_TYPE_COERCIONS = [(float, int), (str, int), (str, float)] -IS_DEVELOPMENT_CONFIG_KEY: str = "IS_DEVELOPMENT_CONFIG" CHECK_INTEGRITY_F: str = "check_integrity" -TConfiguration = TypeVar("TConfiguration", bound=Type[BaseConfiguration]) -# TODO: remove production configuration support -TProductionConfiguration = TypeVar("TProductionConfiguration", bound=Type[BaseConfiguration]) - - -def make_configuration(config: TConfiguration, - production_config: TProductionConfiguration, - initial_values: StrAny = None, - accept_partial: bool = False, - skip_subclass_check: bool = False) -> TConfiguration: - if not skip_subclass_check: - assert issubclass(production_config, config) - - final_config: TConfiguration = config if _is_development_config() else production_config - possible_keys_in_config = _get_config_attrs_with_hints(final_config) - # create dynamic class type to not touch original config variables - derived_config: TConfiguration = cast(TConfiguration, - type(final_config.__name__ + "_" + uniq_id(), (final_config, ), {}) - ) - # apply initial values while preserving hints - derived_config.apply_dict(initial_values) - - _apply_environ_to_config(derived_config, possible_keys_in_config) +TConfiguration = TypeVar("TConfiguration", 
bound=BaseConfiguration) + + +def make_configuration(config: TConfiguration, initial_value: Any = None, accept_partial: bool = False) -> TConfiguration: + if not isinstance(config, BaseConfiguration): + raise ConfigurationWrongTypeException(type(config)) + + # get fields to resolve as per dataclasses PEP + fields = _get_resolvable_fields(config) + # parse initial value if possible + if initial_value is not None: + try: + config.from_native_representation(initial_value) + except (NotImplementedError, ValueError): + # if parsing failed and initial_values is dict then apply + if isinstance(initial_value, Mapping): + config.update(initial_value) + + _resolve_config_fields(config, fields, accept_partial) try: - _is_config_bounded(derived_config, possible_keys_in_config) - _check_configuration_integrity(derived_config) + _is_config_bounded(config, fields) + _check_configuration_integrity(config) # full configuration was resolved - derived_config.__is_partial__ = False + config.__is_partial__ = False except ConfigEntryMissingException: if not accept_partial: raise - _add_module_version(derived_config) - - return derived_config - + _add_module_version(config) -def is_direct_descendant(child: Type[Any], base: Type[Any]) -> bool: - # TODO: there may be faster way to get direct descendant that mro - # note: at index zero there's child - return base == type.mro(child)[1] + return config -def _is_development_config() -> bool: - # get from environment - is_dev_config: bool = None - try: - is_dev_config = _coerce_single_value(IS_DEVELOPMENT_CONFIG_KEY, environ.get_key(IS_DEVELOPMENT_CONFIG_KEY, bool), bool) - except ConfigEnvValueCannotBeCoercedException as coer_exc: - # pass for None: this key may not be present - if coer_exc.env_value is None: - pass - else: - # anything else that cannot corece must raise - raise - return True if is_dev_config is None else is_dev_config +# def is_direct_descendant(child: Type[Any], base: Type[Any]) -> bool: +# # TODO: there may be faster way to get direct descendant that mro +# # note: at index zero there's child +# return base == type.mro(child)[1] def _add_module_version(config: TConfiguration) -> None: try: v = sys._getframe(1).f_back.f_globals["__version__"] semver.VersionInfo.parse(v) - setattr(config, "_VERSION", v) # noqa: B010 + setattr(config, "_version", v) # noqa: B010 except KeyError: pass -def _apply_environ_to_config(config: TConfiguration, keys_in_config: Mapping[str, type]) -> None: - for key, hint in keys_in_config.items(): +def _resolve_config_fields(config: TConfiguration, fields: Mapping[str, type], accept_partial: bool) -> None: + for key, hint in fields.items(): + # get default value + resolved_value = getattr(config, key, None) + # resolve key value via active providers value = environ.get_key(key, hint, config.__namespace__) - if value is not None: - value_from_environment_variable = _coerce_single_value(key, value, hint) - # set value - setattr(config, key, value_from_environment_variable) - - -def _is_config_bounded(config: TConfiguration, keys_in_config: Mapping[str, type]) -> None: - # TODO: here we assume all keys are taken from environ provider, that should change when we introduce more providers - _unbound_attrs = [ - environ.get_key_name(key, config.__namespace__) for key in keys_in_config if getattr(config, key) is None and not is_optional_type(keys_in_config[key]) - ] - - if len(_unbound_attrs) > 0: - raise ConfigEntryMissingException(_unbound_attrs, config.__namespace__) - -def _check_configuration_integrity(config: TConfiguration) -> 
None: - # python multi-inheritance is cooperative and this would require that all configurations cooperatively - # call each other check_integrity. this is not at all possible as we do not know which configs in the end will - # be mixed together. - - # get base classes in order of derivation - mro = type.mro(config) - for c in mro: - # check if this class implements check_integrity (skip pure inheritance to not do double work) - if CHECK_INTEGRITY_F in c.__dict__ and callable(getattr(c, CHECK_INTEGRITY_F)): - # access unbounded __func__ to pass right class type so we check settings of the tip of mro - c.__dict__[CHECK_INTEGRITY_F].__func__(config) + # extract hint from Optional / NewType hints + hint = _extract_simple_type(hint) + # if hint is BaseConfiguration then resolve it recursively + if inspect.isclass(hint) and issubclass(hint, BaseConfiguration): + if isinstance(resolved_value, BaseConfiguration): + # if actual value is BaseConfiguration, resolve that instance + resolved_value = make_configuration(resolved_value, accept_partial=accept_partial) + else: + # create new instance and pass value from the provider as initial + resolved_value = make_configuration(hint(), initial_value=value or resolved_value, accept_partial=accept_partial) + else: + if value is not None: + resolved_value = _coerce_single_value(key, value, hint) + # set value resolved value + setattr(config, key, resolved_value) def _coerce_single_value(key: str, value: str, hint: Type[Any]) -> Any: try: - hint_primitive_type = _extract_simple_type(hint) - if hint_primitive_type not in NON_EVAL_TYPES: + if hint not in NON_EVAL_TYPES: # create primitive types out of strings typed_value = eval(value) # nosec # for primitive types check coercion - if hint_primitive_type in SIMPLE_TYPES and type(typed_value) != hint_primitive_type: + if hint in SIMPLE_TYPES and type(typed_value) != hint: # allow some exceptions coerce_exception = next( - (e for e in ALLOWED_TYPE_COERCIONS if e == (hint_primitive_type, type(typed_value))), None) + (e for e in ALLOWED_TYPE_COERCIONS if e == (hint, type(typed_value))), None) if coerce_exception: - return hint_primitive_type(typed_value) + return hint(typed_value) else: raise ConfigEnvValueCannotBeCoercedException(key, typed_value, hint) return typed_value @@ -140,18 +110,32 @@ def _coerce_single_value(key: str, value: str, hint: Type[Any]) -> Any: raise ConfigEnvValueCannotBeCoercedException(key, value, hint) from exc -def _get_config_attrs_with_hints(config: TConfiguration) -> Dict[str, type]: - keys: Dict[str, type] = {} - mro = type.mro(config) - for cls in reversed(mro): - # update in reverse derivation order so derived classes overwrite hints from base classes - if cls is not object: - keys.update( - [(attr, cls.__annotations__.get(attr, None)) - # if hasattr(config, '__annotations__') and attr in config.__annotations__ else None) - for attr in cls.__dict__.keys() if not callable(getattr(cls, attr)) and not attr.startswith("__") - ]) - return keys +def _is_config_bounded(config: TConfiguration, fields: Mapping[str, type]) -> None: + # TODO: here we assume all keys are taken from environ provider, that should change when we introduce more providers + _unbound_attrs = [ + environ.get_key_name(key, config.__namespace__) for key in fields if getattr(config, key) is None and not is_optional_type(fields[key]) + ] + + if len(_unbound_attrs) > 0: + raise ConfigEntryMissingException(_unbound_attrs, config.__namespace__) + + +def _check_configuration_integrity(config: TConfiguration) -> None: + # 
python multi-inheritance is cooperative and this would require that all configurations cooperatively + # call each other check_integrity. this is not at all possible as we do not know which configs in the end will + # be mixed together. + + # get base classes in order of derivation + mro = type.mro(type(config)) + for c in mro: + # check if this class implements check_integrity (skip pure inheritance to not do double work) + if CHECK_INTEGRITY_F in c.__dict__ and callable(getattr(c, CHECK_INTEGRITY_F)): + # pass right class instance + c.__dict__[CHECK_INTEGRITY_F](config) + + +def _get_resolvable_fields(config: TConfiguration) -> Dict[str, type]: + return {f.name:f.type for f in dataclasses.fields(config) if not f.name.startswith("__")} def _extract_simple_type(hint: Type[Any]) -> Type[Any]: diff --git a/dlt/common/logger.py b/dlt/common/logger.py index 9de24cb809..2c97c7f71b 100644 --- a/dlt/common/logger.py +++ b/dlt/common/logger.py @@ -5,7 +5,7 @@ from sentry_sdk.transport import HttpTransport from sentry_sdk.integrations.logging import LoggingIntegration from logging import LogRecord, Logger -from typing import Any, Type, Protocol +from typing import Any, Protocol from dlt.common.json import json from dlt.common.typing import DictStrAny, StrStr @@ -126,9 +126,9 @@ def wrapper(msg: str, *args: Any, **kwargs: Any) -> None: return wrapper -def _extract_version_info(config: Type[RunConfiguration]) -> StrStr: - version_info = {"version": __version__, "component_name": config.PIPELINE_NAME} - version = getattr(config, "_VERSION", None) +def _extract_version_info(config: RunConfiguration) -> StrStr: + version_info = {"version": __version__, "component_name": config.pipeline_name} + version = getattr(config, "_version", None) if version: version_info["component_version"] = version # extract envs with build info @@ -150,8 +150,8 @@ def _get_pool_options(self, *a: Any, **kw: Any) -> DictStrAny: return rv -def _get_sentry_log_level(C: Type[RunConfiguration]) -> LoggingIntegration: - log_level = logging._nameToLevel[C.LOG_LEVEL] +def _get_sentry_log_level(C: RunConfiguration) -> LoggingIntegration: + log_level = logging._nameToLevel[C.log_level] event_level = logging.WARNING if log_level <= logging.WARNING else log_level return LoggingIntegration( level=logging.INFO, # Capture info and above as breadcrumbs @@ -159,14 +159,14 @@ def _get_sentry_log_level(C: Type[RunConfiguration]) -> LoggingIntegration: ) -def _init_sentry(C: Type[RunConfiguration], version: StrStr) -> None: +def _init_sentry(C: RunConfiguration, version: StrStr) -> None: sys_ver = version["version"] release = sys_ver + "_" + version.get("commit_sha", "") - _SentryHttpTransport.timeout = C.REQUEST_TIMEOUT[0] + _SentryHttpTransport.timeout = C.request_timeout[0] # TODO: ignore certain loggers ie. 
dbt loggers # https://docs.sentry.io/platforms/python/guides/logging/ sentry_sdk.init( - C.SENTRY_DSN, + C.sentry_dsn, integrations=[_get_sentry_log_level(C)], release=release, transport=_SentryHttpTransport @@ -180,17 +180,17 @@ def _init_sentry(C: Type[RunConfiguration], version: StrStr) -> None: sentry_sdk.set_tag(k, v) -def init_telemetry(config: Type[RunConfiguration]) -> None: - if config.PROMETHEUS_PORT: +def init_telemetry(config: RunConfiguration) -> None: + if config.prometheus_port: from prometheus_client import start_http_server, Info - logging.info(f"Starting prometheus server port {config.PROMETHEUS_PORT}") - start_http_server(config.PROMETHEUS_PORT) + logging.info(f"Starting prometheus server port {config.prometheus_port}") + start_http_server(config.prometheus_port) # collect info Info("runs_component_name", "Name of the executing component").info(_extract_version_info(config)) -def init_logging_from_config(C: Type[RunConfiguration]) -> None: +def init_logging_from_config(C: RunConfiguration) -> None: global LOGGER # add HEALTH and METRICS log levels @@ -201,11 +201,11 @@ def init_logging_from_config(C: Type[RunConfiguration]) -> None: version = _extract_version_info(C) LOGGER = _init_logging( DLT_LOGGER_NAME, - C.LOG_LEVEL, - C.LOG_FORMAT, - C.PIPELINE_NAME, + C.log_level, + C.log_format, + C.pipeline_name, version) - if C.SENTRY_DSN: + if C.sentry_dsn: _init_sentry(C, version) diff --git a/dlt/common/runners/init.py b/dlt/common/runners/init.py index 508702f2bd..1a5025374e 100644 --- a/dlt/common/runners/init.py +++ b/dlt/common/runners/init.py @@ -10,7 +10,7 @@ _INITIALIZED = False -def initialize_runner(C: Type[RunConfiguration]) -> None: +def initialize_runner(C: RunConfiguration) -> None: global _INITIALIZED # initialize or re-initialize logging with new settings diff --git a/dlt/common/runners/pool_runner.py b/dlt/common/runners/pool_runner.py index d5409f974f..be1a271de8 100644 --- a/dlt/common/runners/pool_runner.py +++ b/dlt/common/runners/pool_runner.py @@ -1,6 +1,6 @@ import multiprocessing from prometheus_client import Counter, Gauge, Summary, CollectorRegistry, REGISTRY -from typing import Callable, Dict, Type, Union, cast +from typing import Callable, Dict, Union, cast from multiprocessing.pool import ThreadPool, Pool from dlt.common import logger, signals @@ -39,23 +39,23 @@ def update_gauges() -> TRunHealth: return get_metrics_from_prometheus(HEALTH_PROPS_GAUGES.values()) # type: ignore -def run_pool(C: Type[PoolRunnerConfiguration], run_f: Union[Runnable[TPool], Callable[[TPool], TRunMetrics]]) -> int: +def run_pool(C: PoolRunnerConfiguration, run_f: Union[Runnable[TPool], Callable[[TPool], TRunMetrics]]) -> int: # create health gauges if not HEALTH_PROPS_GAUGES: create_gauges(REGISTRY) # start pool pool: Pool = None - if C.POOL_TYPE == "process": + if C.pool_type == "process": # our pool implementation do not work on spawn if multiprocessing.get_start_method() != "fork": raise UnsupportedProcessStartMethodException(multiprocessing.get_start_method()) - pool = Pool(processes=C.WORKERS) - elif C.POOL_TYPE == "thread": - pool = ThreadPool(processes=C.WORKERS) + pool = Pool(processes=C.workers) + elif C.pool_type == "thread": + pool = ThreadPool(processes=C.workers) else: pool = None - logger.info(f"Created {C.POOL_TYPE} pool with {C.WORKERS or 'default no.'} workers") + logger.info(f"Created {C.pool_type} pool with {C.workers or 'default no.'} workers") # track local stats runs_count = 0 runs_not_idle_count = 0 @@ -88,7 +88,7 @@ def run_pool(C: 
Type[PoolRunnerConfiguration], run_f: Union[Runnable[TPool], Cal global LAST_RUN_EXCEPTION LAST_RUN_EXCEPTION = exc # re-raise if EXIT_ON_EXCEPTION is requested - if C.EXIT_ON_EXCEPTION: + if C.exit_on_exception: raise finally: if run_metrics: @@ -101,22 +101,22 @@ def run_pool(C: Type[PoolRunnerConfiguration], run_f: Union[Runnable[TPool], Cal # single run may be forced but at least wait_runs must pass # and was all the time idle or (was not idle but now pending is 0) - if C.IS_SINGLE_RUN and (runs_count >= C.WAIT_RUNS and (runs_not_idle_count == 0 or run_metrics.pending_items == 0)): + if C.is_single_run and (runs_count >= C.wait_runs and (runs_not_idle_count == 0 or run_metrics.pending_items == 0)): logger.info("Stopping runner due to single run override") return 0 if run_metrics.has_failed: - sleep(C.RUN_SLEEP_WHEN_FAILED) + sleep(C.run_sleep_when_failed) elif run_metrics.pending_items == 0: # nothing is pending so we can sleep longer - sleep(C.RUN_SLEEP_IDLE) + sleep(C.run_sleep_idle) else: # more items are pending, sleep (typically) shorter - sleep(C.RUN_SLEEP) + sleep(C.run_sleep) # this allows to recycle long living process that get their memory fragmented # exit after runners sleeps so we keep the running period - if runs_count == C.STOP_AFTER_RUNS: + if runs_count == C.stop_after_runs: logger.warning(f"Stopping runner due to max runs {runs_count} exceeded") return 0 except SignalReceivedException as sigex: @@ -129,9 +129,13 @@ def run_pool(C: Type[PoolRunnerConfiguration], run_f: Union[Runnable[TPool], Cal finally: if pool: logger.info("Closing processing pool") - pool.close() - pool.join() + # terminate pool and do not join + pool.terminate() + # in very rare cases process hangs here, even with starmap terminating earlier + # pool.close() + # pool.join() pool = None + logger.info("Closing processing pool closed") def _update_metrics(run_metrics: TRunMetrics) -> TRunHealth: diff --git a/dlt/common/runners/runnable.py b/dlt/common/runners/runnable.py index 053d275f3c..a4e775c0bb 100644 --- a/dlt/common/runners/runnable.py +++ b/dlt/common/runners/runnable.py @@ -9,7 +9,6 @@ from dlt.common.typing import StrAny, TFun from dlt.common.utils import uniq_id from dlt.common.telemetry import TRunMetrics -from dlt.common.configuration.utils import TConfiguration TPool = TypeVar("TPool", bound=Pool) @@ -62,40 +61,40 @@ def _wrap(rid: Union[int, Runnable[TPool]], *args: Any, **kwargs: Any) -> Any: return _wrap # type: ignore -def configuredworker(f: TFun) -> TFun: - """Decorator for a process/thread pool worker function facilitates passing bound configuration type across the process boundary. It requires the first method - of the worker function to be annotated with type derived from Type[BaseConfiguration] and the worker function to be called (typically by the Pool class) with a - configuration values serialized to dict (via `as_dict` method). The decorator will synthesize a new derived type and apply the serialized value, mimicking the - original type to be transferred across the process boundary. - - Args: - f (TFun): worker function to be decorated - - Raises: - ValueError: raised when worker function signature does not contain required parameters or/and annotations - - - Returns: - TFun: wrapped worker function - """ - @wraps(f) - def _wrap(config: Union[StrAny, Type[BaseConfiguration]], *args: Any, **kwargs: Any) -> Any: - if isinstance(config, Mapping): - # worker process may run in separate process started with spawn and should not share any state with the parent process ie. 
global variables like config - # first function parameter should be of Type[BaseConfiguration] - sig = inspect.signature(f) - try: - first_param: inspect.Parameter = next(iter(sig.parameters.values())) - T = get_args(first_param.annotation)[0] - if not issubclass(T, BaseConfiguration): - raise ValueError(T) - except Exception: - raise ValueError(f"First parameter of wrapped worker method {f.__name__} must by annotated as Type[BaseConfiguration]") - CONFIG = type(f.__name__ + uniq_id(), (T, ), {}) - CONFIG.apply_dict(config) # type: ignore - config = CONFIG - - return f(config, *args, **kwargs) - - return _wrap # type: ignore +# def configuredworker(f: TFun) -> TFun: +# """Decorator for a process/thread pool worker function facilitates passing bound configuration type across the process boundary. It requires the first method +# of the worker function to be annotated with type derived from Type[BaseConfiguration] and the worker function to be called (typically by the Pool class) with a +# configuration values serialized to dict (via `as_dict` method). The decorator will synthesize a new derived type and apply the serialized value, mimicking the +# original type to be transferred across the process boundary. + +# Args: +# f (TFun): worker function to be decorated + +# Raises: +# ValueError: raised when worker function signature does not contain required parameters or/and annotations + + +# Returns: +# TFun: wrapped worker function +# """ +# @wraps(f) +# def _wrap(config: Union[StrAny, Type[BaseConfiguration]], *args: Any, **kwargs: Any) -> Any: +# if isinstance(config, Mapping): +# # worker process may run in separate process started with spawn and should not share any state with the parent process ie. global variables like config +# # first function parameter should be of Type[BaseConfiguration] +# sig = inspect.signature(f) +# try: +# first_param: inspect.Parameter = next(iter(sig.parameters.values())) +# T = get_args(first_param.annotation)[0] +# if not issubclass(T, BaseConfiguration): +# raise ValueError(T) +# except Exception: +# raise ValueError(f"First parameter of wrapped worker method {f.__name__} must by annotated as Type[BaseConfiguration]") +# CONFIG = type(f.__name__ + uniq_id(), (T, ), {}) +# CONFIG.apply_dict(config) # type: ignore +# config = CONFIG + +# return f(config, *args, **kwargs) + +# return _wrap # type: ignore diff --git a/dlt/common/storages/live_schema_storage.py b/dlt/common/storages/live_schema_storage.py index 3c1a131f09..47b6265f28 100644 --- a/dlt/common/storages/live_schema_storage.py +++ b/dlt/common/storages/live_schema_storage.py @@ -1,4 +1,4 @@ -from typing import Dict, Type +from typing import Dict from dlt.common.configuration import SchemaVolumeConfiguration from dlt.common.schema.schema import Schema @@ -6,7 +6,7 @@ class LiveSchemaStorage(SchemaStorage): - def __init__(self, C: Type[SchemaVolumeConfiguration], makedirs: bool = False) -> None: + def __init__(self, C: SchemaVolumeConfiguration, makedirs: bool = False) -> None: self.live_schemas: Dict[str, Schema] = {} super().__init__(C, makedirs) @@ -36,9 +36,9 @@ def commit_live_schema(self, name: str) -> Schema: if live_schema and live_schema.stored_version_hash != live_schema.version_hash: print("bumping and saving") live_schema.bump_version() - if self.C.IMPORT_SCHEMA_PATH: + if self.C.import_schema_path: # overwrite import schemas if specified - self._export_schema(live_schema, self.C.IMPORT_SCHEMA_PATH) + self._export_schema(live_schema, self.C.import_schema_path) else: # write directly to 
schema storage if no import schema folder configured self._save_schema(live_schema) diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index c8b139800c..06d0ecfacf 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -1,7 +1,7 @@ import os from os.path import join from pathlib import Path -from typing import Iterable, NamedTuple, Literal, Optional, Sequence, Set, Tuple, Type, get_args +from typing import Iterable, NamedTuple, Literal, Optional, Sequence, Set, get_args from dlt.common import json, pendulum from dlt.common.typing import DictStrAny, StrAny @@ -44,7 +44,7 @@ class LoadStorage(DataItemStorage, VersionedStorage): def __init__( self, is_owner: bool, - C: Type[LoadVolumeConfiguration], + C: LoadVolumeConfiguration, preferred_file_format: TLoaderFileFormat, supported_file_formats: Iterable[TLoaderFileFormat] ) -> None: @@ -53,11 +53,11 @@ def __init__( if preferred_file_format not in supported_file_formats: raise TerminalValueError(preferred_file_format) self.supported_file_formats = supported_file_formats - self.delete_completed_jobs = C.DELETE_COMPLETED_JOBS + self.delete_completed_jobs = C.delete_completed_jobs super().__init__( preferred_file_format, LoadStorage.STORAGE_VERSION, - is_owner, FileStorage(C.LOAD_VOLUME_PATH, "t", makedirs=is_owner) + is_owner, FileStorage(C.load_volume_path, "t", makedirs=is_owner) ) if is_owner: self.initialize_storage() diff --git a/dlt/common/storages/normalize_storage.py b/dlt/common/storages/normalize_storage.py index 485bc480a1..b0572f8423 100644 --- a/dlt/common/storages/normalize_storage.py +++ b/dlt/common/storages/normalize_storage.py @@ -1,4 +1,4 @@ -from typing import List, Sequence, Tuple, Type, NamedTuple +from typing import List, Sequence, NamedTuple from itertools import groupby from pathlib import Path @@ -19,8 +19,8 @@ class NormalizeStorage(VersionedStorage): STORAGE_VERSION = "1.0.0" EXTRACTED_FOLDER: str = "extracted" # folder within the volume where extracted files to be normalized are stored - def __init__(self, is_owner: bool, C: Type[NormalizeVolumeConfiguration]) -> None: - super().__init__(NormalizeStorage.STORAGE_VERSION, is_owner, FileStorage(C.NORMALIZE_VOLUME_PATH, "t", makedirs=is_owner)) + def __init__(self, is_owner: bool, C: NormalizeVolumeConfiguration) -> None: + super().__init__(NormalizeStorage.STORAGE_VERSION, is_owner, FileStorage(C.normalize_volume_path, "t", makedirs=is_owner)) self.CONFIG = C if is_owner: self.initialize_storage() diff --git a/dlt/common/storages/schema_storage.py b/dlt/common/storages/schema_storage.py index 6ef8beeb32..bfe1912312 100644 --- a/dlt/common/storages/schema_storage.py +++ b/dlt/common/storages/schema_storage.py @@ -1,7 +1,7 @@ import os import re import yaml -from typing import Iterator, List, Type, Mapping +from typing import Iterator, List, Mapping from dlt.common import json, logger from dlt.common.configuration.schema_volume_configuration import TSchemaFileFormat @@ -18,9 +18,9 @@ class SchemaStorage(Mapping[str, Schema]): SCHEMA_FILE_NAME = "schema.%s" NAMED_SCHEMA_FILE_PATTERN = f"%s_{SCHEMA_FILE_NAME}" - def __init__(self, C: Type[SchemaVolumeConfiguration], makedirs: bool = False) -> None: + def __init__(self, C: SchemaVolumeConfiguration, makedirs: bool = False) -> None: self.C = C - self.storage = FileStorage(C.SCHEMA_VOLUME_PATH, makedirs=makedirs) + self.storage = FileStorage(C.schema_volume_path, makedirs=makedirs) def load_schema(self, name: str) -> Schema: # loads a schema from a 
store holding many schemas @@ -30,21 +30,21 @@ def load_schema(self, name: str) -> Schema: storage_schema = json.loads(self.storage.load(schema_file)) # prevent external modifications of schemas kept in storage if not verify_schema_hash(storage_schema, empty_hash_verifies=True): - raise InStorageSchemaModified(name, self.C.SCHEMA_VOLUME_PATH) + raise InStorageSchemaModified(name, self.C.schema_volume_path) except FileNotFoundError: # maybe we can import from external storage pass # try to import from external storage - if self.C.IMPORT_SCHEMA_PATH: + if self.C.import_schema_path: return self._maybe_import_schema(name, storage_schema) if storage_schema is None: - raise SchemaNotFoundError(name, self.C.SCHEMA_VOLUME_PATH) + raise SchemaNotFoundError(name, self.C.schema_volume_path) return Schema.from_dict(storage_schema) def save_schema(self, schema: Schema) -> str: # check if there's schema to import - if self.C.IMPORT_SCHEMA_PATH: + if self.C.import_schema_path: try: imported_schema = Schema.from_dict(self._load_import_schema(schema.name)) # link schema being saved to current imported schema so it will not overwrite this save when loaded @@ -53,8 +53,8 @@ def save_schema(self, schema: Schema) -> str: # just save the schema pass path = self._save_schema(schema) - if self.C.EXPORT_SCHEMA_PATH: - self._export_schema(schema, self.C.EXPORT_SCHEMA_PATH) + if self.C.export_schema_path: + self._export_schema(schema, self.C.export_schema_path) return path def remove_schema(self, name: str) -> None: @@ -111,37 +111,37 @@ def _maybe_import_schema(self, name: str, storage_schema: DictStrAny = None) -> except FileNotFoundError: # no schema to import -> skip silently and return the original if storage_schema is None: - raise SchemaNotFoundError(name, self.C.SCHEMA_VOLUME_PATH, self.C.IMPORT_SCHEMA_PATH, self.C.EXTERNAL_SCHEMA_FORMAT) + raise SchemaNotFoundError(name, self.C.schema_volume_path, self.C.import_schema_path, self.C.external_schema_format) rv_schema = Schema.from_dict(storage_schema) assert rv_schema is not None return rv_schema def _load_import_schema(self, name: str) -> DictStrAny: - import_storage = FileStorage(self.C.IMPORT_SCHEMA_PATH, makedirs=False) - schema_file = self._file_name_in_store(name, self.C.EXTERNAL_SCHEMA_FORMAT) + import_storage = FileStorage(self.C.import_schema_path, makedirs=False) + schema_file = self._file_name_in_store(name, self.C.external_schema_format) imported_schema: DictStrAny = None imported_schema_s = import_storage.load(schema_file) - if self.C.EXTERNAL_SCHEMA_FORMAT == "json": + if self.C.external_schema_format == "json": imported_schema = json.loads(imported_schema_s) - elif self.C.EXTERNAL_SCHEMA_FORMAT == "yaml": + elif self.C.external_schema_format == "yaml": imported_schema = yaml.safe_load(imported_schema_s) else: - raise ValueError(self.C.EXTERNAL_SCHEMA_FORMAT) + raise ValueError(self.C.external_schema_format) return imported_schema def _export_schema(self, schema: Schema, export_path: str) -> None: - if self.C.EXTERNAL_SCHEMA_FORMAT == "json": - exported_schema_s = schema.to_pretty_json(remove_defaults=self.C.EXTERNAL_SCHEMA_FORMAT_REMOVE_DEFAULTS) - elif self.C.EXTERNAL_SCHEMA_FORMAT == "yaml": - exported_schema_s = schema.to_pretty_yaml(remove_defaults=self.C.EXTERNAL_SCHEMA_FORMAT_REMOVE_DEFAULTS) + if self.C.external_schema_format == "json": + exported_schema_s = schema.to_pretty_json(remove_defaults=self.C.external_schema_format_remove_defaults) + elif self.C.external_schema_format == "yaml": + exported_schema_s = 
schema.to_pretty_yaml(remove_defaults=self.C.external_schema_format_remove_defaults) else: - raise ValueError(self.C.EXTERNAL_SCHEMA_FORMAT) + raise ValueError(self.C.external_schema_format) export_storage = FileStorage(export_path, makedirs=True) - schema_file = self._file_name_in_store(schema.name, self.C.EXTERNAL_SCHEMA_FORMAT) + schema_file = self._file_name_in_store(schema.name, self.C.external_schema_format) export_storage.save(schema_file, exported_schema_s) - logger.info(f"Schema {schema.name} exported to {export_path} with version {schema.stored_version} as {self.C.EXTERNAL_SCHEMA_FORMAT}") + logger.info(f"Schema {schema.name} exported to {export_path} with version {schema.stored_version} as {self.C.external_schema_format}") def _save_schema(self, schema: Schema) -> str: # save a schema to schema store diff --git a/dlt/common/typing.py b/dlt/common/typing.py index e22f2ae21a..69a0ea941f 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -1,6 +1,6 @@ from collections.abc import Mapping as C_Mapping, Sequence as C_Sequence from re import Pattern as _REPattern -from typing import Callable, Dict, Any, Literal, Mapping, NewType, Tuple, Type, TypeVar, Generic, Protocol, TYPE_CHECKING, Union, runtime_checkable, get_args, get_origin +from typing import Callable, Dict, Any, Literal, Mapping, NewType, Tuple, Type, TypeVar, Generic, Protocol, Iterable, TYPE_CHECKING, Union, runtime_checkable, get_args, get_origin if TYPE_CHECKING: from _typeshed import StrOrBytesPath from typing import _TypedDict @@ -25,7 +25,6 @@ TVariantRV = Tuple[str, Any] VARIANT_FIELD_FORMAT = "v_%s" - @runtime_checkable class SupportsVariant(Protocol, Generic[TVariantBase]): """Defines variant type protocol that should be recognized by normalizers diff --git a/dlt/common/utils.py b/dlt/common/utils.py index d8bc2d3491..03ca293917 100644 --- a/dlt/common/utils.py +++ b/dlt/common/utils.py @@ -1,10 +1,13 @@ -from functools import wraps import os +from pathlib import Path +import sys import base64 -from contextlib import contextmanager import hashlib -from os import environ import secrets +from contextlib import contextmanager +from functools import wraps +from os import environ + from typing import Any, Dict, Iterator, Optional, Sequence, TypeVar, Mapping, List, TypedDict, Union from dlt.common.typing import StrAny, DictStrAny, StrStr, TFun @@ -169,3 +172,9 @@ def encoding_for_mode(mode: str) -> Optional[str]: return None else: return "utf-8" + + +def entry_point_file_stem() -> str: + if len(sys.argv) > 0 and os.path.isfile(sys.argv[0]): + return Path(sys.argv[0]).stem + return None diff --git a/dlt/dbt_runner/configuration.py b/dlt/dbt_runner/configuration.py index f9279c20b1..947a5e5f15 100644 --- a/dlt/dbt_runner/configuration.py +++ b/dlt/dbt_runner/configuration.py @@ -1,70 +1,56 @@ +import dataclasses from typing import List, Optional, Type from dlt.common.typing import StrAny, TSecretValue -from dlt.common.configuration import make_configuration from dlt.common.configuration.providers import environ -from dlt.common.configuration import PoolRunnerConfiguration, TPoolType, PostgresCredentials, GcpClientCredentials +from dlt.common.configuration import PoolRunnerConfiguration, TPoolType, PostgresCredentials, GcpClientCredentials, make_configuration, configspec from . 
import __version__ +@configspec class DBTRunnerConfiguration(PoolRunnerConfiguration): - POOL_TYPE: TPoolType = "none" - STOP_AFTER_RUNS: int = 1 - PACKAGE_VOLUME_PATH: str = "_storage/dbt_runner" - PACKAGE_REPOSITORY_URL: str = "https://github.com/scale-vector/rasa_semantic_schema_customization.git" - PACKAGE_REPOSITORY_BRANCH: Optional[str] = None - PACKAGE_REPOSITORY_SSH_KEY: TSecretValue = TSecretValue("") # the default is empty value which will disable custom SSH KEY - PACKAGE_PROFILES_DIR: str = "." - PACKAGE_PROFILE_PREFIX: str = "rasa_semantic_schema" - PACKAGE_SOURCE_TESTS_SELECTOR: str = "tag:prerequisites" - PACKAGE_ADDITIONAL_VARS: Optional[StrAny] = None - PACKAGE_RUN_PARAMS: List[str] = ["--fail-fast"] - AUTO_FULL_REFRESH_WHEN_OUT_OF_SYNC: bool = True - - SOURCE_SCHEMA_PREFIX: str = None - DEST_SCHEMA_PREFIX: Optional[str] = None - - @classmethod - def check_integrity(cls) -> None: - if cls.PACKAGE_REPOSITORY_SSH_KEY and cls.PACKAGE_REPOSITORY_SSH_KEY[-1] != "\n": + pool_type: TPoolType = "none" + stop_after_runs: int = 1 + package_volume_path: str = "/var/local/app" + package_repository_url: str = "https://github.com/scale-vector/rasa_semantic_schema_customization.git" + package_repository_branch: Optional[str] = None + package_repository_ssh_key: TSecretValue = TSecretValue("") # the default is empty value which will disable custom SSH KEY + package_profiles_dir: str = "." + package_profile_prefix: str = "rasa_semantic_schema" + package_source_tests_selector: str = "tag:prerequisites" + package_additional_vars: Optional[StrAny] = None + package_run_params: List[str] = dataclasses.field(default_factory=lambda: ["--fail-fast"]) + auto_full_refresh_when_out_of_sync: bool = True + + source_schema_prefix: str = None + dest_schema_prefix: Optional[str] = None + + def check_integrity(self) -> None: + if self.package_repository_ssh_key and self.package_repository_ssh_key[-1] != "\n": # must end with new line, otherwise won't be parsed by Crypto - cls.PACKAGE_REPOSITORY_SSH_KEY = TSecretValue(cls.PACKAGE_REPOSITORY_SSH_KEY + "\n") - if cls.STOP_AFTER_RUNS != 1: + self.package_repository_ssh_key = TSecretValue(self.package_repository_ssh_key + "\n") + if self.stop_after_runs != 1: # always stop after one run - cls.STOP_AFTER_RUNS = 1 - + self.stop_after_runs = 1 -class DBTRunnerProductionConfiguration(DBTRunnerConfiguration): - PACKAGE_VOLUME_PATH: str = "/var/local/app" # this is actually not exposed as volume - PACKAGE_REPOSITORY_URL: str = None - -def gen_configuration_variant(initial_values: StrAny = None) -> Type[DBTRunnerConfiguration]: +def gen_configuration_variant(initial_values: StrAny = None) -> DBTRunnerConfiguration: # derive concrete config depending on env vars present DBTRunnerConfigurationImpl: Type[DBTRunnerConfiguration] - DBTRunnerProductionConfigurationImpl: Type[DBTRunnerProductionConfiguration] - source_schema_prefix = environ.get_key("DEFAULT_DATASET", type(str)) + source_schema_prefix = environ.get_key("default_dataset", type(str)) - if environ.get_key("PROJECT_ID", type(str), namespace=GcpClientCredentials.__namespace__): + if environ.get_key("project_id", type(str), namespace=GcpClientCredentials.__namespace__): + @configspec class DBTRunnerConfigurationPostgres(PostgresCredentials, DBTRunnerConfiguration): SOURCE_SCHEMA_PREFIX: str = source_schema_prefix DBTRunnerConfigurationImpl = DBTRunnerConfigurationPostgres - class DBTRunnerProductionConfigurationPostgres(DBTRunnerProductionConfiguration, DBTRunnerConfigurationPostgres): - pass - # 
SOURCE_SCHEMA_PREFIX: str = source_schema_prefix - DBTRunnerProductionConfigurationImpl = DBTRunnerProductionConfigurationPostgres - else: + @configspec class DBTRunnerConfigurationGcp(GcpClientCredentials, DBTRunnerConfiguration): SOURCE_SCHEMA_PREFIX: str = source_schema_prefix DBTRunnerConfigurationImpl = DBTRunnerConfigurationGcp - class DBTRunnerProductionConfigurationGcp(DBTRunnerProductionConfiguration, DBTRunnerConfigurationGcp): - pass - # SOURCE_SCHEMA_PREFIX: str = source_schema_prefix - DBTRunnerProductionConfigurationImpl = DBTRunnerProductionConfigurationGcp - - return make_configuration(DBTRunnerConfigurationImpl, DBTRunnerProductionConfigurationImpl, initial_values=initial_values) + return make_configuration(DBTRunnerConfigurationImpl(), initial_value=initial_values) diff --git a/dlt/dbt_runner/runner.py b/dlt/dbt_runner/runner.py index ac186ab59c..af4cf92825 100644 --- a/dlt/dbt_runner/runner.py +++ b/dlt/dbt_runner/runner.py @@ -1,13 +1,13 @@ -from typing import Optional, Sequence, Tuple, Type +from typing import Optional, Sequence, Tuple from git import GitError from prometheus_client import REGISTRY, Gauge, CollectorRegistry, Info from prometheus_client.metrics import MetricWrapperBase -from dlt.common.configuration import GcpClientCredentials from dlt.common import logger from dlt.common.typing import DictStrAny, DictStrStr, StrAny from dlt.common.logger import is_json_logging from dlt.common.telemetry import get_logging_extras +from dlt.common.configuration import GcpClientCredentials from dlt.common.file_storage import FileStorage from dlt.cli import TRunnerArgs from dlt.common.runners import initialize_runner, run_pool @@ -20,7 +20,7 @@ CLONED_PACKAGE_NAME = "dbt_package" -CONFIG: Type[DBTRunnerConfiguration] = None +CONFIG: DBTRunnerConfiguration = None storage: FileStorage = None dbt_package_vars: StrAny = None global_args: Sequence[str] = None @@ -32,28 +32,28 @@ def create_folders() -> Tuple[FileStorage, StrAny, Sequence[str], str, str]: - storage = FileStorage(CONFIG.PACKAGE_VOLUME_PATH, makedirs=True) + storage = FileStorage(CONFIG.package_volume_path, makedirs=True) dbt_package_vars: DictStrAny = { - "source_schema_prefix": CONFIG.SOURCE_SCHEMA_PREFIX + "source_schema_prefix": CONFIG.source_schema_prefix } - if CONFIG.DEST_SCHEMA_PREFIX: - dbt_package_vars["dest_schema_prefix"] = CONFIG.DEST_SCHEMA_PREFIX - if CONFIG.PACKAGE_ADDITIONAL_VARS: - dbt_package_vars.update(CONFIG.PACKAGE_ADDITIONAL_VARS) + if CONFIG.dest_schema_prefix: + dbt_package_vars["dest_schema_prefix"] = CONFIG.dest_schema_prefix + if CONFIG.package_additional_vars: + dbt_package_vars.update(CONFIG.package_additional_vars) # initialize dbt logging, returns global parameters to dbt command - global_args = initialize_dbt_logging(CONFIG.LOG_LEVEL, is_json_logging(CONFIG.LOG_FORMAT)) + global_args = initialize_dbt_logging(CONFIG.log_level, is_json_logging(CONFIG.log_format)) # generate path for the dbt package repo repo_path = storage.make_full_path(CLONED_PACKAGE_NAME) # generate profile name profile_name: str = None - if CONFIG.PACKAGE_PROFILE_PREFIX: - if issubclass(CONFIG, GcpClientCredentials): - profile_name = "%s_bigquery" % (CONFIG.PACKAGE_PROFILE_PREFIX) + if CONFIG.package_profile_prefix: + if isinstance(CONFIG, GcpClientCredentials): + profile_name = "%s_bigquery" % (CONFIG.package_profile_prefix) else: - profile_name = "%s_redshift" % (CONFIG.PACKAGE_PROFILE_PREFIX) + profile_name = "%s_redshift" % (CONFIG.package_profile_prefix) return storage, dbt_package_vars, global_args, 
repo_path, profile_name @@ -69,7 +69,7 @@ def run_dbt(command: str, command_args: Sequence[str] = None) -> Sequence[dbt_re logger.info(f"Exec dbt command: {global_args} {command} {command_args} {dbt_package_vars} on profile {profile_name or ''}") return run_dbt_command( repo_path, command, - CONFIG.PACKAGE_PROFILES_DIR, + CONFIG.package_profiles_dir, profile_name=profile_name, command_args=command_args, global_args=global_args, @@ -109,8 +109,8 @@ def initialize_package(with_git_command: Optional[str]) -> None: # cleanup package folder if storage.has_folder(CLONED_PACKAGE_NAME): storage.delete_folder(CLONED_PACKAGE_NAME, recursively=True) - logger.info(f"Will clone {CONFIG.PACKAGE_REPOSITORY_URL} head {CONFIG.PACKAGE_REPOSITORY_BRANCH} into {repo_path}") - clone_repo(CONFIG.PACKAGE_REPOSITORY_URL, repo_path, branch=CONFIG.PACKAGE_REPOSITORY_BRANCH, with_git_command=with_git_command) + logger.info(f"Will clone {CONFIG.package_repository_url} head {CONFIG.package_repository_branch} into {repo_path}") + clone_repo(CONFIG.package_repository_url, repo_path, branch=CONFIG.package_repository_branch, with_git_command=with_git_command) run_dbt("deps") except Exception: # delete folder so we start clean next time @@ -120,7 +120,7 @@ def initialize_package(with_git_command: Optional[str]) -> None: def ensure_newest_package() -> None: - with git_custom_key_command(CONFIG.PACKAGE_REPOSITORY_SSH_KEY) as ssh_command: + with git_custom_key_command(CONFIG.package_repository_ssh_key) as ssh_command: try: ensure_remote_head(repo_path, with_git_command=ssh_command) except GitError as err: @@ -134,8 +134,8 @@ def run_db_steps() -> Sequence[dbt_results.BaseResult]: ensure_newest_package() # check if raw schema exists try: - if CONFIG.PACKAGE_SOURCE_TESTS_SELECTOR: - run_dbt("test", ["-s", CONFIG.PACKAGE_SOURCE_TESTS_SELECTOR]) + if CONFIG.package_source_tests_selector: + run_dbt("test", ["-s", CONFIG.package_source_tests_selector]) except DBTProcessingError as err: raise PrerequisitesException() from err @@ -143,12 +143,12 @@ def run_db_steps() -> Sequence[dbt_results.BaseResult]: run_dbt("seed") # throws DBTProcessingError try: - return run_dbt("run", CONFIG.PACKAGE_RUN_PARAMS) + return run_dbt("run", CONFIG.package_run_params) except DBTProcessingError as e: # detect incremental model out of sync - if is_incremental_schema_out_of_sync_error(e.results) and CONFIG.AUTO_FULL_REFRESH_WHEN_OUT_OF_SYNC: + if is_incremental_schema_out_of_sync_error(e.results) and CONFIG.auto_full_refresh_when_out_of_sync: logger.warning(f"Attempting full refresh due to incremental model out of sync on {e.results.message}") - return run_dbt("run", CONFIG.PACKAGE_RUN_PARAMS + ["--full-refresh"]) + return run_dbt("run", CONFIG.package_run_params + ["--full-refresh"]) else: raise @@ -172,7 +172,7 @@ def run(_: None) -> TRunMetrics: raise -def configure(C: Type[DBTRunnerConfiguration], collector: CollectorRegistry) -> None: +def configure(C: DBTRunnerConfiguration, collector: CollectorRegistry) -> None: global CONFIG global storage, dbt_package_vars, global_args, repo_path, profile_name global model_elapsed_gauge, model_exec_info @@ -183,7 +183,7 @@ def configure(C: Type[DBTRunnerConfiguration], collector: CollectorRegistry) -> model_elapsed_gauge, model_exec_info = create_gauges(REGISTRY) except ValueError as v: # ignore re-creation of gauges - if "Duplicated timeseries" not in str(v): + if "Duplicated time-series" not in str(v): raise diff --git a/dlt/helpers/streamlit.py b/dlt/helpers/streamlit.py index ed1c3f624e..fc66fc3988 100644 
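[Editor's note: illustrative sketch, not part of the patch.] The hunks above replace the class-level, UPPER_CASE configuration model with @configspec instances whose snake_case fields are resolved in place by make_configuration(). A minimal usage sketch, assuming the resolution logic shown in dlt/common/configuration/utils.py above, that all remaining PoolRunnerConfiguration fields have defaults or are supplied via environment variables, and with MyRunnerConfiguration as a hypothetical class:

    from dlt.common.configuration import PoolRunnerConfiguration, TPoolType, configspec, make_configuration

    @configspec
    class MyRunnerConfiguration(PoolRunnerConfiguration):
        # same override style as DBTRunnerConfiguration above
        pool_type: TPoolType = "thread"
        workers: int = 4

    # initial_value is applied to the instance first, then fields are resolved from the
    # environ provider; check_integrity() is now an instance method invoked during resolution
    C = make_configuration(MyRunnerConfiguration(), initial_value={"workers": 8})
    print(C.workers)  # 8, unless an overriding environment variable is set
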
--- a/dlt/helpers/streamlit.py +++ b/dlt/helpers/streamlit.py @@ -9,7 +9,7 @@ from dlt.pipeline.typing import credentials_from_dict from dlt.pipeline.exceptions import MissingDependencyException, PipelineException from dlt.helpers.pandas import query_results_to_df, pd -from dlt.common.configuration.run_configuration import BaseConfiguration, CredentialsConfiguration +from dlt.common.configuration.base_configuration import BaseConfiguration, CredentialsConfiguration from dlt.common.utils import dict_remove_nones_in_place try: @@ -35,8 +35,8 @@ def restore_pipeline() -> Pipeline: raise PipelineException("You must backup pipeline to Streamlit first") dlt_cfg = secrets["dlt"] credentials = deepcopy(dict(dlt_cfg["destination"])) - if "DEFAULT_SCHEMA_NAME" in credentials: - del credentials["DEFAULT_SCHEMA_NAME"] + if "default_schema_name" in credentials: + del credentials["default_schema_name"] credentials.update(dlt_cfg["credentials"]) pipeline = Pipeline(dlt_cfg["pipeline_name"]) pipeline.restore_pipeline(credentials_from_dict(credentials), dlt_cfg["working_dir"]) @@ -77,8 +77,8 @@ def backup_pipeline(pipeline: Pipeline) -> None: # save client config # print(dict_remove_nones_in_place(CONFIG.as_dict(lowercase=False))) dlt_c = cast(TomlContainer, secrets["dlt"]) - dlt_c["destination"] = dict_remove_nones_in_place(CONFIG.as_dict(lowercase=False)) - dlt_c["credentials"] = dict_remove_nones_in_place(CREDENTIALS.as_dict(lowercase=False)) + dlt_c["destination"] = dict_remove_nones_in_place(dict(CONFIG)) + dlt_c["credentials"] = dict_remove_nones_in_place(dict(CREDENTIALS)) with open(SECRETS_FILE_LOC, "w", encoding="utf-8") as f: # use whitespace preserving parser diff --git a/dlt/load/bigquery/client.py b/dlt/load/bigquery/client.py index d1a01e8cec..20dfe7b4df 100644 --- a/dlt/load/bigquery/client.py +++ b/dlt/load/bigquery/client.py @@ -1,6 +1,6 @@ from pathlib import Path from contextlib import contextmanager -from typing import Any, AnyStr, Dict, Iterator, List, Optional, Sequence, Tuple, Type +from typing import Any, AnyStr, Dict, Iterator, List, Optional, Sequence, Tuple import google.cloud.bigquery as bigquery # noqa: I250 from google.cloud.bigquery.dbapi import Connection as DbApiConnection from google.cloud import exceptions as gcp_exceptions @@ -52,12 +52,12 @@ class BigQuerySqlClient(SqlClientBase[bigquery.Client]): - def __init__(self, default_dataset_name: str, CREDENTIALS: Type[GcpClientCredentials]) -> None: + def __init__(self, default_dataset_name: str, CREDENTIALS: GcpClientCredentials) -> None: self._client: bigquery.Client = None self.C = CREDENTIALS super().__init__(default_dataset_name) - self.default_retry = bigquery.DEFAULT_RETRY.with_deadline(CREDENTIALS.RETRY_DEADLINE) + self.default_retry = bigquery.DEFAULT_RETRY.with_deadline(CREDENTIALS.retry_deadline) self.default_query = bigquery.QueryJobConfig(default_dataset=self.fully_qualified_dataset_name()) def open_connection(self) -> None: @@ -65,8 +65,8 @@ def open_connection(self) -> None: if self.C.__is_partial__: credentials = None else: - credentials = service_account.Credentials.from_service_account_info(self.C.as_credentials()) - self._client = bigquery.Client(self.C.PROJECT_ID, credentials=credentials, location=self.C.LOCATION) + credentials = service_account.Credentials.from_service_account_info(self.C.to_native_representation()) + self._client = bigquery.Client(self.C.project_id, credentials=credentials, location=self.C.location) def close_connection(self) -> None: if self._client: @@ -79,7 +79,7 @@ def 
native_connection(self) -> bigquery.Client: def has_dataset(self) -> bool: try: - self._client.get_dataset(self.fully_qualified_dataset_name(), retry=self.default_retry, timeout=self.C.HTTP_TIMEOUT) + self._client.get_dataset(self.fully_qualified_dataset_name(), retry=self.default_retry, timeout=self.C.http_timeout) return True except gcp_exceptions.NotFound: return False @@ -89,7 +89,7 @@ def create_dataset(self) -> None: self.fully_qualified_dataset_name(), exists_ok=False, retry=self.default_retry, - timeout=self.C.HTTP_TIMEOUT + timeout=self.C.http_timeout ) def drop_dataset(self) -> None: @@ -98,7 +98,7 @@ def drop_dataset(self) -> None: not_found_ok=True, delete_contents=True, retry=self.default_retry, - timeout=self.C.HTTP_TIMEOUT + timeout=self.C.http_timeout ) def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: @@ -106,7 +106,7 @@ def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequen def_kwargs = { "job_config": self.default_query, "job_retry": self.default_retry, - "timeout": self.C.HTTP_TIMEOUT + "timeout": self.C.http_timeout } kwargs = {**def_kwargs, **(kwargs or {})} results = self._client.query(sql, *args, **kwargs).result() @@ -135,19 +135,19 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[D conn.close() def fully_qualified_dataset_name(self) -> str: - return f"{self.C.PROJECT_ID}.{self.default_dataset_name}" + return f"{self.C.project_id}.{self.default_dataset_name}" class BigQueryLoadJob(LoadJob): - def __init__(self, file_name: str, bq_load_job: bigquery.LoadJob, CONFIG: Type[GcpClientCredentials]) -> None: + def __init__(self, file_name: str, bq_load_job: bigquery.LoadJob, CONFIG: GcpClientCredentials) -> None: self.bq_load_job = bq_load_job self.C = CONFIG - self.default_retry = bigquery.DEFAULT_RETRY.with_deadline(CONFIG.RETRY_DEADLINE) + self.default_retry = bigquery.DEFAULT_RETRY.with_deadline(CONFIG.retry_deadline) super().__init__(file_name) def status(self) -> LoadJobStatus: # check server if done - done = self.bq_load_job.done(retry=self.default_retry, timeout=self.C.HTTP_TIMEOUT) + done = self.bq_load_job.done(retry=self.default_retry, timeout=self.C.http_timeout) if done: # rows processed if self.bq_load_job.output_rows is not None and self.bq_load_job.error_result is None: @@ -183,12 +183,12 @@ def exception(self) -> str: class BigQueryClient(SqlJobClientBase): - CONFIG: Type[BigQueryClientConfiguration] = None - CREDENTIALS: Type[GcpClientCredentials] = None + CONFIG: BigQueryClientConfiguration = None + CREDENTIALS: GcpClientCredentials = None def __init__(self, schema: Schema) -> None: sql_client = BigQuerySqlClient( - schema.normalize_make_dataset_name(self.CONFIG.DEFAULT_DATASET, self.CONFIG.DEFAULT_SCHEMA_NAME, schema.name), + schema.normalize_make_dataset_name(self.CONFIG.default_dataset, self.CONFIG.default_schema_name, schema.name), self.CREDENTIALS ) super().__init__(schema, sql_client) @@ -296,7 +296,7 @@ def _get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns schema_table: TTableSchemaColumns = {} try: table = self.sql_client.native_connection.get_table( - self.sql_client.make_qualified_table_name(table_name), retry=self.sql_client.default_retry, timeout=self.CREDENTIALS.HTTP_TIMEOUT + self.sql_client.make_qualified_table_name(table_name), retry=self.sql_client.default_retry, timeout=self.CREDENTIALS.http_timeout ) partition_field = table.time_partitioning.field if table.time_partitioning else None for c in 
table.schema: @@ -333,7 +333,7 @@ def _create_load_job(self, table_name: str, write_disposition: TWriteDisposition self.sql_client.make_qualified_table_name(table_name), job_id=job_id, job_config=job_config, - timeout=self.CREDENTIALS.HTTP_TIMEOUT + timeout=self.CREDENTIALS.http_timeout ) def _retrieve_load_job(self, file_path: str) -> bigquery.LoadJob: @@ -380,7 +380,7 @@ def capabilities(cls) -> TLoaderCapabilities: } @classmethod - def configure(cls, initial_values: StrAny = None) -> Tuple[Type[BigQueryClientConfiguration], Type[GcpClientCredentials]]: + def configure(cls, initial_values: StrAny = None) -> Tuple[BigQueryClientConfiguration, GcpClientCredentials]: cls.CONFIG, cls.CREDENTIALS = configuration(initial_values=initial_values) return cls.CONFIG, cls.CREDENTIALS diff --git a/dlt/load/bigquery/configuration.py b/dlt/load/bigquery/configuration.py index 8bca223db3..b4c8d248e7 100644 --- a/dlt/load/bigquery/configuration.py +++ b/dlt/load/bigquery/configuration.py @@ -1,37 +1,38 @@ -from typing import Tuple, Type +from typing import Tuple from google.auth import default as default_credentials from google.auth.exceptions import DefaultCredentialsError from dlt.common.typing import StrAny -from dlt.common.configuration import make_configuration, GcpClientCredentials +from dlt.common.configuration import make_configuration, GcpClientCredentials, configspec from dlt.common.configuration.exceptions import ConfigEntryMissingException from dlt.load.configuration import LoaderClientDwhConfiguration +@configspec class BigQueryClientConfiguration(LoaderClientDwhConfiguration): - CLIENT_TYPE: str = "bigquery" + client_type: str = "bigquery" -def configuration(initial_values: StrAny = None) -> Tuple[Type[BigQueryClientConfiguration], Type[GcpClientCredentials]]: +def configuration(initial_values: StrAny = None) -> Tuple[BigQueryClientConfiguration, GcpClientCredentials]: - def maybe_partial_credentials() -> Type[GcpClientCredentials]: + def maybe_partial_credentials() -> GcpClientCredentials: try: - return make_configuration(GcpClientCredentials, GcpClientCredentials, initial_values=initial_values) + return make_configuration(GcpClientCredentials(), initial_value=initial_values) except ConfigEntryMissingException as cfex: # if config is missing check if credentials can be obtained from defaults try: _, project_id = default_credentials() # if so then return partial so we can access timeouts - C_PARTIAL = make_configuration(GcpClientCredentials, GcpClientCredentials, initial_values=initial_values, accept_partial = True) + C_PARTIAL = make_configuration(GcpClientCredentials(), initial_value=initial_values, accept_partial = True) # set the project id - it needs to be known by the client - C_PARTIAL.PROJECT_ID = C_PARTIAL.PROJECT_ID or project_id + C_PARTIAL.project_id = C_PARTIAL.project_id or project_id return C_PARTIAL except DefaultCredentialsError: raise cfex return ( - make_configuration(BigQueryClientConfiguration, BigQueryClientConfiguration, initial_values=initial_values), + make_configuration(BigQueryClientConfiguration(), initial_value=initial_values), # allow partial credentials so the client can fallback to default credentials maybe_partial_credentials() ) diff --git a/dlt/load/client_base.py b/dlt/load/client_base.py index 92a92faa2c..1d07a30f7d 100644 --- a/dlt/load/client_base.py +++ b/dlt/load/client_base.py @@ -111,7 +111,7 @@ def capabilities(cls) -> TLoaderCapabilities: @classmethod @abstractmethod - def configure(cls, initial_values: StrAny = None) -> 
Tuple[Type[BaseConfiguration], Type[CredentialsConfiguration]]: + def configure(cls, initial_values: StrAny = None) -> Tuple[BaseConfiguration, CredentialsConfiguration]: pass diff --git a/dlt/load/configuration.py b/dlt/load/configuration.py index bd5da8935c..f502a48379 100644 --- a/dlt/load/configuration.py +++ b/dlt/load/configuration.py @@ -1,32 +1,26 @@ -from typing import Any, Optional, Type -from dlt.common.configuration.run_configuration import BaseConfiguration +from typing import Optional from dlt.common.typing import StrAny -from dlt.common.configuration import (PoolRunnerConfiguration, - LoadVolumeConfiguration, - ProductionLoadVolumeConfiguration, - TPoolType, make_configuration) -from . import __version__ +from dlt.common.configuration import BaseConfiguration, PoolRunnerConfiguration, LoadVolumeConfiguration, TPoolType, make_configuration, configspec +from . import __version__ +@configspec class LoaderClientConfiguration(BaseConfiguration): - CLIENT_TYPE: str = None # which destination to load data to + client_type: str = None # which destination to load data to +@configspec class LoaderClientDwhConfiguration(LoaderClientConfiguration): - DEFAULT_DATASET: str = None # dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix - DEFAULT_SCHEMA_NAME: Optional[str] = None # name of default schema to be used to name effective dataset to load data to + default_dataset: str = None # dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix + default_schema_name: Optional[str] = None # name of default schema to be used to name effective dataset to load data to +@configspec class LoaderConfiguration(PoolRunnerConfiguration, LoadVolumeConfiguration, LoaderClientConfiguration): - WORKERS: int = 20 # how many parallel loads can be executed - # MAX_PARALLELISM: int = 20 # in 20 separate threads - POOL_TYPE: TPoolType = "thread" # mostly i/o (upload) so may be thread pool - - -class ProductionLoaderConfiguration(ProductionLoadVolumeConfiguration, LoaderConfiguration): - pass + workers: int = 20 # how many parallel loads can be executed + pool_type: TPoolType = "thread" # mostly i/o (upload) so may be thread pool -def configuration(initial_values: StrAny = None) -> Type[LoaderConfiguration]: - return make_configuration(LoaderConfiguration, ProductionLoaderConfiguration, initial_values=initial_values) +def configuration(initial_values: StrAny = None) -> LoaderConfiguration: + return make_configuration(LoaderConfiguration(), initial_value=initial_values) diff --git a/dlt/load/dummy/client.py b/dlt/load/dummy/client.py index 76b724001e..6bb85c62cd 100644 --- a/dlt/load/dummy/client.py +++ b/dlt/load/dummy/client.py @@ -78,7 +78,7 @@ class DummyClient(JobClientBase): """ dummy client storing jobs in memory """ - CONFIG: Type[DummyClientConfiguration] = None + CONFIG: DummyClientConfiguration = None def __init__(self, schema: Schema) -> None: super().__init__(schema) @@ -120,17 +120,17 @@ def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb def _create_job(self, job_id: str) -> LoadDummyJob: return LoadDummyJob( job_id, - fail_prob=self.CONFIG.FAIL_PROB, - retry_prob=self.CONFIG.RETRY_PROB, - completed_prob=self.CONFIG.COMPLETED_PROB, - timeout=self.CONFIG.TIMEOUT + fail_prob=self.CONFIG.fail_prob, + retry_prob=self.CONFIG.retry_prob, + completed_prob=self.CONFIG.completed_prob, + timeout=self.CONFIG.timeout ) @classmethod def capabilities(cls) -> 
TLoaderCapabilities: return { - "preferred_loader_file_format": cls.CONFIG.LOADER_FILE_FORMAT, - "supported_loader_file_formats": [cls.CONFIG.LOADER_FILE_FORMAT], + "preferred_loader_file_format": cls.CONFIG.loader_file_format, + "supported_loader_file_formats": [cls.CONFIG.loader_file_format], "max_identifier_length": 127, "max_column_length": 127, "max_query_length": 8 * 1024 * 1024, @@ -140,7 +140,7 @@ def capabilities(cls) -> TLoaderCapabilities: } @classmethod - def configure(cls, initial_values: StrAny = None) -> Tuple[Type[DummyClientConfiguration], Type[CredentialsConfiguration]]: + def configure(cls, initial_values: StrAny = None) -> Tuple[DummyClientConfiguration, CredentialsConfiguration]: cls.CONFIG = configuration(initial_values=initial_values) return cls.CONFIG, None diff --git a/dlt/load/dummy/configuration.py b/dlt/load/dummy/configuration.py index 87eef77817..93d9907258 100644 --- a/dlt/load/dummy/configuration.py +++ b/dlt/load/dummy/configuration.py @@ -1,20 +1,19 @@ -from typing import Type - from dlt.common.typing import StrAny -from dlt.common.configuration import make_configuration +from dlt.common.configuration import make_configuration, configspec from dlt.common.data_writers import TLoaderFileFormat from dlt.load.configuration import LoaderClientConfiguration +@configspec class DummyClientConfiguration(LoaderClientConfiguration): - CLIENT_TYPE: str = "dummy" - LOADER_FILE_FORMAT: TLoaderFileFormat = "jsonl" - FAIL_PROB: float = 0.0 - RETRY_PROB: float = 0.0 - COMPLETED_PROB: float = 0.0 - TIMEOUT: float = 10.0 + client_type: str = "dummy" + loader_file_format: TLoaderFileFormat = "jsonl" + fail_prob: float = 0.0 + retry_prob: float = 0.0 + completed_prob: float = 0.0 + timeout: float = 10.0 -def configuration(initial_values: StrAny = None) -> Type[DummyClientConfiguration]: - return make_configuration(DummyClientConfiguration, DummyClientConfiguration, initial_values=initial_values) +def configuration(initial_values: StrAny = None) -> DummyClientConfiguration: + return make_configuration(DummyClientConfiguration(), initial_value=initial_values) diff --git a/dlt/load/load.py b/dlt/load/load.py index 54ed8fa4f1..ddcec05682 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -31,9 +31,9 @@ class Load(Runnable[ThreadPool]): job_counter: Counter = None job_wait_summary: Summary = None - def __init__(self, C: Type[LoaderConfiguration], collector: CollectorRegistry, client_initial_values: StrAny = None, is_storage_owner: bool = False) -> None: + def __init__(self, C: LoaderConfiguration, collector: CollectorRegistry, client_initial_values: StrAny = None, is_storage_owner: bool = False) -> None: self.CONFIG = C - self.load_client_cls = self.import_client_cls(C.CLIENT_TYPE, initial_values=client_initial_values) + self.load_client_cls = self.import_client_cls(C.client_type, initial_values=client_initial_values) self.pool: ThreadPool = None self.load_storage: LoadStorage = self.create_storage(is_storage_owner) try: @@ -113,7 +113,7 @@ def spool_new_jobs(self, load_id: str, schema: Schema) -> Tuple[int, List[LoadJo # use thread based pool as jobs processing is mostly I/O and we do not want to pickle jobs # TODO: combine files by providing a list of files pertaining to same table into job, so job must be # extended to accept a list - load_files = self.load_storage.list_new_jobs(load_id)[:self.CONFIG.WORKERS] + load_files = self.load_storage.list_new_jobs(load_id)[:self.CONFIG.workers] file_count = len(load_files) if file_count == 0: logger.info(f"No new jobs found in 
{load_id}") @@ -208,11 +208,11 @@ def run(self, pool: ThreadPool) -> TRunMetrics: logger.info(f"Loaded schema name {schema.name} and version {schema.stored_version}") # initialize analytical storage ie. create dataset required by passed schema with self.load_client_cls(schema) as client: - logger.info(f"Client {self.CONFIG.CLIENT_TYPE} will start load") + logger.info(f"Client {self.CONFIG.client_type} will start load") client.initialize_storage() schema_update = self.load_storage.begin_schema_update(load_id) if schema_update: - logger.info(f"Client {self.CONFIG.CLIENT_TYPE} will update schema to package schema") + logger.info(f"Client {self.CONFIG.client_type} will update schema to package schema") # TODO: this should rather generate an SQL job(s) to be executed PRE loading client.update_storage_schema() self.load_storage.commit_schema_update(load_id) diff --git a/dlt/load/redshift/client.py b/dlt/load/redshift/client.py index 764c04ea3b..889615c241 100644 --- a/dlt/load/redshift/client.py +++ b/dlt/load/redshift/client.py @@ -9,7 +9,7 @@ from contextlib import contextmanager -from typing import Any, AnyStr, Dict, Iterator, List, Optional, Sequence, Tuple, Type +from typing import Any, AnyStr, Dict, Iterator, List, Optional, Sequence, Tuple from dlt.common.configuration.postgres_credentials import PostgresCredentials from dlt.common.typing import StrAny @@ -57,14 +57,14 @@ class RedshiftSqlClient(SqlClientBase["psycopg2.connection"]): - def __init__(self, default_dataset_name: str, CREDENTIALS: Type[PostgresCredentials]) -> None: + def __init__(self, default_dataset_name: str, CREDENTIALS: PostgresCredentials) -> None: super().__init__(default_dataset_name) self._conn: psycopg2.connection = None self.C = CREDENTIALS def open_connection(self) -> None: self._conn = psycopg2.connect( - **self.C.as_dict(), + **self.C, options=f"-c search_path={self.fully_qualified_dataset_name()},public" ) # we'll provide explicit transactions @@ -200,12 +200,12 @@ def _insert(self, qualified_table_name: str, write_disposition: TWriteDispositio class RedshiftClient(SqlJobClientBase): - CONFIG: Type[RedshiftClientConfiguration] = None - CREDENTIALS: Type[PostgresCredentials] = None + CONFIG: RedshiftClientConfiguration = None + CREDENTIALS: PostgresCredentials = None def __init__(self, schema: Schema) -> None: sql_client = RedshiftSqlClient( - schema.normalize_make_dataset_name(self.CONFIG.DEFAULT_DATASET, self.CONFIG.DEFAULT_SCHEMA_NAME, schema.name), + schema.normalize_make_dataset_name(self.CONFIG.default_dataset, self.CONFIG.default_schema_name, schema.name), self.CREDENTIALS ) super().__init__(schema, sql_client) @@ -347,7 +347,7 @@ def capabilities(cls) -> TLoaderCapabilities: } @classmethod - def configure(cls, initial_values: StrAny = None) -> Tuple[Type[RedshiftClientConfiguration], Type[PostgresCredentials]]: + def configure(cls, initial_values: StrAny = None) -> Tuple[RedshiftClientConfiguration, PostgresCredentials]: cls.CONFIG, cls.CREDENTIALS = configuration(initial_values=initial_values) return cls.CONFIG, cls.CREDENTIALS diff --git a/dlt/load/redshift/configuration.py b/dlt/load/redshift/configuration.py index cd883ed885..c17c39292e 100644 --- a/dlt/load/redshift/configuration.py +++ b/dlt/load/redshift/configuration.py @@ -1,17 +1,18 @@ -from typing import Tuple, Type +from typing import Tuple from dlt.common.typing import StrAny -from dlt.common.configuration import make_configuration, PostgresCredentials +from dlt.common.configuration import configspec, make_configuration, PostgresCredentials 
from dlt.load.configuration import LoaderClientDwhConfiguration +@configspec class RedshiftClientConfiguration(LoaderClientDwhConfiguration): - CLIENT_TYPE: str = "redshift" + client_type: str = "redshift" -def configuration(initial_values: StrAny = None) -> Tuple[Type[RedshiftClientConfiguration], Type[PostgresCredentials]]: +def configuration(initial_values: StrAny = None) -> Tuple[RedshiftClientConfiguration, PostgresCredentials]: return ( - make_configuration(RedshiftClientConfiguration, RedshiftClientConfiguration, initial_values=initial_values), - make_configuration(PostgresCredentials, PostgresCredentials, initial_values=initial_values) + make_configuration(RedshiftClientConfiguration(), initial_value=initial_values), + make_configuration(PostgresCredentials(), initial_value=initial_values) ) diff --git a/dlt/normalize/configuration.py b/dlt/normalize/configuration.py index ef8bda141d..53359f3c91 100644 --- a/dlt/normalize/configuration.py +++ b/dlt/normalize/configuration.py @@ -1,25 +1,17 @@ -from typing import Type - from dlt.common.typing import StrAny from dlt.common.data_writers import TLoaderFileFormat from dlt.common.configuration import (PoolRunnerConfiguration, NormalizeVolumeConfiguration, LoadVolumeConfiguration, SchemaVolumeConfiguration, - ProductionLoadVolumeConfiguration, ProductionNormalizeVolumeConfiguration, - ProductionSchemaVolumeConfiguration, - TPoolType, make_configuration) + TPoolType, make_configuration, configspec) from . import __version__ +@configspec class NormalizeConfiguration(PoolRunnerConfiguration, NormalizeVolumeConfiguration, LoadVolumeConfiguration, SchemaVolumeConfiguration): - LOADER_FILE_FORMAT: TLoaderFileFormat = "jsonl" # jsonp or insert commands will be generated - POOL_TYPE: TPoolType = "process" - - -class ProductionNormalizeConfiguration(ProductionNormalizeVolumeConfiguration, ProductionLoadVolumeConfiguration, - ProductionSchemaVolumeConfiguration, NormalizeConfiguration): - pass + loader_file_format: TLoaderFileFormat = "jsonl" # jsonp or insert commands will be generated + pool_type: TPoolType = "process" -def configuration(initial_values: StrAny = None) -> Type[NormalizeConfiguration]: - return make_configuration(NormalizeConfiguration, ProductionNormalizeConfiguration, initial_values=initial_values) +def configuration(initial_values: StrAny = None) -> NormalizeConfiguration: + return make_configuration(NormalizeConfiguration(), initial_value=initial_values) diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 81761d881f..07163b2c87 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -7,8 +7,7 @@ from dlt.common.json import custom_pua_decode from dlt.cli import TRunnerArgs from dlt.common.normalizers.json import wrap_in_dict -from dlt.common.runners import TRunMetrics, Runnable, run_pool, initialize_runner, workermethod -from dlt.common.runners.runnable import configuredworker +from dlt.common.runners import TRunMetrics, Runnable, run_pool, initialize_runner from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.storages.exceptions import SchemaNotFoundError from dlt.common.storages import NormalizeStorage, SchemaStorage, LoadStorage @@ -35,7 +34,7 @@ class Normalize(Runnable[ProcessPool]): schema_version_gauge: Gauge = None load_package_counter: Counter = None - def __init__(self, C: Type[NormalizeConfiguration], collector: CollectorRegistry = REGISTRY, schema_storage: SchemaStorage = None) -> None: + def __init__(self, C: NormalizeConfiguration, collector: 
CollectorRegistry = REGISTRY, schema_storage: SchemaStorage = None) -> None: self.CONFIG = C self.pool: ProcessPool = None self.normalize_storage: NormalizeStorage = None @@ -63,7 +62,7 @@ def create_gauges(registry: CollectorRegistry) -> None: def create_storages(self) -> None: self.normalize_storage = NormalizeStorage(True, self.CONFIG) # normalize saves in preferred format but can read all supported formats - self.load_storage = LoadStorage(True, self.CONFIG, self.CONFIG.LOADER_FILE_FORMAT, LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + self.load_storage = LoadStorage(True, self.CONFIG, self.CONFIG.loader_file_format, LoadStorage.ALL_SUPPORTED_FILE_FORMATS) @staticmethod @@ -77,10 +76,9 @@ def load_or_create_schema(schema_storage: SchemaStorage, schema_name: str) -> Sc return schema @staticmethod - @configuredworker - def w_normalize_files(CONFIG: Type[NormalizeConfiguration], schema_name: str, load_id: str, extracted_items_files: Sequence[str]) -> Tuple[TSchemaUpdate, int]: + def w_normalize_files(CONFIG: NormalizeConfiguration, schema_name: str, load_id: str, extracted_items_files: Sequence[str]) -> Tuple[TSchemaUpdate, int]: schema = Normalize.load_or_create_schema(SchemaStorage(CONFIG, makedirs=False), schema_name) - load_storage = LoadStorage(False, CONFIG, CONFIG.LOADER_FILE_FORMAT, LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + load_storage = LoadStorage(False, CONFIG, CONFIG.loader_file_format, LoadStorage.ALL_SUPPORTED_FILE_FORMATS) normalize_storage = NormalizeStorage(False, CONFIG) schema_update: TSchemaUpdate = {} @@ -143,7 +141,7 @@ def w_normalize_files(CONFIG: Type[NormalizeConfiguration], schema_name: str, lo def map_parallel(self, schema_name: str, load_id: str, files: Sequence[str]) -> TMapFuncRV: # TODO: maybe we should chunk by file size, now map all files to workers chunk_files = [files] - param_chunk = [(self.CONFIG.as_dict(), schema_name, load_id, files) for files in chunk_files] + param_chunk = [(self.CONFIG, schema_name, load_id, files) for files in chunk_files] processed_chunks = self.pool.starmap(Normalize.w_normalize_files, param_chunk) return sum([t[1] for t in processed_chunks]), [t[0] for t in processed_chunks], chunk_files diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 4d611fed37..2e55ba00b8 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -45,7 +45,7 @@ def __init__(self, pipeline_name: str, log_level: str = "INFO") -> None: self._loader_instance: Load = None # patch config and initialize pipeline - self.C = make_configuration(PoolRunnerConfiguration, PoolRunnerConfiguration, initial_values={ + self.C = make_configuration(PoolRunnerConfiguration(), initial_value={ "PIPELINE_NAME": pipeline_name, "LOG_LEVEL": log_level, "POOL_TYPE": "None", @@ -178,10 +178,9 @@ def normalize(self, workers: int = 1, max_events_in_chunk: int = 100000) -> int: raise NotImplementedError("Do not use workers in interactive mode ie. 
in notebook") self._verify_normalize_instance() # set runtime parameters - self._normalize_instance.CONFIG.WORKERS = workers - self._normalize_instance.CONFIG.MAX_EVENTS_IN_CHUNK = max_events_in_chunk + self._normalize_instance.CONFIG.workers = workers # switch to thread pool for single worker - self._normalize_instance.CONFIG.POOL_TYPE = "thread" if workers == 1 else "process" + self._normalize_instance.CONFIG.pool_type = "thread" if workers == 1 else "process" try: ec = runner.run_pool(self._normalize_instance.CONFIG, self._normalize_instance) # in any other case we raise if runner exited with status failed @@ -194,7 +193,7 @@ def normalize(self, workers: int = 1, max_events_in_chunk: int = 100000) -> int: def load(self, max_parallel_loads: int = 20) -> int: self._verify_loader_instance() - self._loader_instance.CONFIG.WORKERS = max_parallel_loads + self._loader_instance.CONFIG.workers = max_parallel_loads self._loader_instance.load_client_cls.CONFIG.DEFAULT_SCHEMA_NAME = self.default_schema_name # type: ignore try: ec = runner.run_pool(self._loader_instance.CONFIG, self._loader_instance) @@ -283,14 +282,14 @@ def sql_client(self, schema_name: str = None) -> SqlClientBase[Any]: if isinstance(c, SqlJobClientBase): return c.sql_client else: - raise SqlClientNotAvailable(self._loader_instance.CONFIG.CLIENT_TYPE) + raise SqlClientNotAvailable(self._loader_instance.CONFIG.client_type) def run_in_pool(self, run_f: Callable[..., Any]) -> int: # internal runners should work in single mode - self._loader_instance.CONFIG.IS_SINGLE_RUN = True - self._loader_instance.CONFIG.EXIT_ON_EXCEPTION = True - self._normalize_instance.CONFIG.IS_SINGLE_RUN = True - self._normalize_instance.CONFIG.EXIT_ON_EXCEPTION = True + self._loader_instance.CONFIG.is_single_run = True + self._loader_instance.CONFIG.exit_on_exception = True + self._normalize_instance.CONFIG.is_single_run = True + self._normalize_instance.CONFIG.exit_on_exception = True def _run(_: Any) -> TRunMetrics: rv = run_f() diff --git a/experiments/pipeline/configuration.py b/experiments/pipeline/configuration.py index f584b0e159..64101b4139 100644 --- a/experiments/pipeline/configuration.py +++ b/experiments/pipeline/configuration.py @@ -1,10 +1,113 @@ -from typing import Any, Type +import os +import inspect +import dataclasses +import tomlkit +from inspect import Signature, Parameter +from typing import Any, List, Type +# from makefun import wraps +from functools import wraps -from dlt.common.typing import DictStrAny, TAny -from dlt.common.configuration.utils import make_configuration +from dlt.common.typing import DictStrAny, StrAny, TAny, TFun +from dlt.common.configuration import BaseConfiguration +from dlt.common.configuration.utils import NON_EVAL_TYPES, make_configuration, SIMPLE_TYPES +# _POS_PARAMETER_KINDS = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD, Parameter.VAR_POSITIONAL) -def get_config(spec: Type[TAny], key: str = None, namespace: str = None, initial_values: Any = None, accept_partial: bool = False) -> Type[TAny]: +def _read_toml(file_name: str) -> StrAny: + config_file_path = os.path.abspath(os.path.join(".", "experiments/.dlt", file_name)) + + if os.path.isfile(config_file_path): + with open(config_file_path, "r", encoding="utf-8") as f: + # use whitespace preserving parser + return tomlkit.load(f) + else: + return {} + + +def get_config_from_toml(): + pass + + +def get_config(SPEC: Type[TAny], key: str = None, namespace: str = None, initial_value: Any = None, accept_partial: bool = False) -> TAny: # TODO: 
implement key and namespace - return make_configuration(spec, spec, initial_values=initial_values, accept_partial=accept_partial) + return make_configuration(SPEC(), initial_value=initial_value, accept_partial=accept_partial) + + +def spec_from_dict(): + pass + + +def spec_from_signature(name: str, sig: Signature) -> Type[BaseConfiguration]: + # synthesize configuration from the signature + fields: List[dataclasses.Field] = [] + for p in sig.parameters.values(): + # skip *args and **kwargs + if p.kind not in (Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL) and p.name not in ["self", "cls"]: + field_type = Any if p.annotation == Parameter.empty else p.annotation + if field_type in SIMPLE_TYPES or field_type in NON_EVAL_TYPES or issubclass(field_type, BaseConfiguration): + field_default = None if p.default == Parameter.empty else dataclasses.field(default=p.default) + if field_default: + # correct the type if Any + if field_type is Any: + field_type = type(p.default) + fields.append((p.name, field_type, field_default)) + else: + fields.append((p.name, field_type)) + print(fields) + SPEC = dataclasses.make_dataclass(name + "_CONFIG", fields, bases=(BaseConfiguration,), init=False) + print("synthesized") + print(SPEC) + # print(SPEC()) + return SPEC + + +def with_config(func = None, /, spec: Type[BaseConfiguration] = None) -> TFun: + + def decorator(f: TFun) -> TFun: + SPEC: Type[BaseConfiguration] = None + sig: Signature = inspect.signature(f) + kwargs_par = next((p for p in sig.parameters.values() if p.kind == Parameter.VAR_KEYWORD), None) + # pos_params = [p.name for p in sig.parameters.values() if p.kind in _POS_PARAMETER_KINDS] + # kw_params = [p.name for p in sig.parameters.values() if p.kind not in _POS_PARAMETER_KINDS] + + if spec is None: + SPEC = spec_from_signature(f.__name__, sig) + else: + SPEC = spec + + @wraps(f) + def _wrap(*args: Any, **kwargs: Any) -> Any: + # for calls providing all parameters to the func, configuration may not be resolved + # if len(args) + len(kwargs) == len(sig.parameters): + # return f(*args, **kwargs) + + # bind parameters to signature + bound_args = sig.bind_partial(*args, **kwargs) + bound_args.apply_defaults() + # resolve SPEC + config = get_config(SPEC, SPEC, initial_value=bound_args.arguments) + resolved_params = dict(config) + print("RESOLVED") + print(resolved_params) + # overwrite or add resolved params + for p in sig.parameters.values(): + if p.name in resolved_params: + bound_args.arguments[p.name] = resolved_params.pop(p.name) + # pass all other config parameters into kwargs if present + if kwargs_par is not None: + bound_args.arguments[kwargs_par.name].update(resolved_params) + # call the function with injected config + return f(*bound_args.args, **bound_args.kwargs) + + return _wrap + + # See if we're being called as @with_config or @with_config(). + if func is None: + # We're called with parens. + return decorator + + if not callable(func): + raise ValueError("First parameter to the with_config must be callable ie. by using it as function decorator") + # We're called as @with_config without parens. 
+ return decorator(func) diff --git a/experiments/pipeline/extract.py b/experiments/pipeline/extract.py index e3209b64dc..f7b4a6e12d 100644 --- a/experiments/pipeline/extract.py +++ b/experiments/pipeline/extract.py @@ -1,5 +1,5 @@ import os -from typing import List, Type +from typing import List from dlt.common.utils import uniq_id from dlt.common.sources import TDirectDataItem, TDataItem @@ -15,7 +15,7 @@ class ExtractorStorage(DataItemStorage, NormalizeStorage): EXTRACT_FOLDER = "extract" - def __init__(self, C: Type[NormalizeVolumeConfiguration]) -> None: + def __init__(self, C: NormalizeVolumeConfiguration) -> None: # data item storage with jsonl with pua encoding super().__init__("puae-jsonl", False, C) self.initialize_storage() diff --git a/experiments/pipeline/pipeline.py b/experiments/pipeline/pipeline.py index b14cb43ba4..e47cecf986 100644 --- a/experiments/pipeline/pipeline.py +++ b/experiments/pipeline/pipeline.py @@ -56,12 +56,11 @@ class TPipelineState(TypedDict): class PipelineConfiguration(RunConfiguration): WORKING_DIR: Optional[str] = None PIPELINE_SECRET: Optional[TSecretValue] = None - drop_existing_data: bool = False + DROP_EXISTING_DATA: bool = False - @classmethod - def check_integrity(cls) -> None: - if cls.PIPELINE_SECRET: - cls.PIPELINE_SECRET = uniq_id() + def check_integrity(self) -> None: + if self.PIPELINE_SECRET: + self.PIPELINE_SECRET = uniq_id() class Pipeline: @@ -157,7 +156,7 @@ def configure(self, **kwargs: Any) -> None: # use system temp folder if not specified if not self.CONFIG.WORKING_DIR: self.CONFIG.WORKING_DIR = tempfile.gettempdir() - self.root_folder = os.path.join(self.CONFIG.WORKING_DIR, self.CONFIG.PIPELINE_NAME) + self.root_folder = os.path.join(self.CONFIG.WORKING_DIR, self.CONFIG.pipeline_name) self._set_common_initial_values() # create pipeline working dir @@ -236,7 +235,7 @@ def extract( # if isinstance(items, str) or isinstance(items, dict) or not # TODO: check if schema exists with self._managed_state(): - default_table_name = table_name or self.CONFIG.PIPELINE_NAME + default_table_name = table_name or self.CONFIG.pipeline_name # TODO: this is not very effective - we consume iterator right away, better implementation needed where we stream iterator to files directly all_items: List[DictStrAny] = [] for item in data: @@ -403,7 +402,7 @@ def _resolve_load_client_config(self) -> Type[LoaderClientDwhConfiguration]: ) def _ensure_destination_name(self) -> str: - d_n = self._resolve_load_client_config().CLIENT_TYPE + d_n = self._resolve_load_client_config().client_type if not d_n: raise PipelineConfigMissing( "destination_name", @@ -413,9 +412,9 @@ def _ensure_destination_name(self) -> str: return d_n def _ensure_default_dataset(self) -> str: - d_n = self._resolve_load_client_config().DEFAULT_DATASET + d_n = self._resolve_load_client_config().default_dataset if not d_n: - d_n = normalize_schema_name(self.CONFIG.PIPELINE_NAME) + d_n = normalize_schema_name(self.CONFIG.pipeline_name) return d_n def _extract_iterator(self, default_table_name: str, items: Sequence[DictStrAny]) -> None: diff --git a/poetry.lock b/poetry.lock index 2a553a87a7..51dfa3b3f7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -304,18 +304,6 @@ typing-extensions = ">=3.7.4.1" all = ["pytz (>=2019.1)"] dates = ["pytz (>=2019.1)"] -[[package]] -name = "fire" -version = "0.4.0" -description = "A library for automatically generating command line interfaces." 
-category = "main" -optional = false -python-versions = "*" - -[package.dependencies] -six = "*" -termcolor = "*" - [[package]] name = "flake8" version = "5.0.4" @@ -1196,17 +1184,6 @@ category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" -[[package]] -name = "randomname" -version = "0.1.5" -description = "Generate random adj-noun names like docker and github." -category = "main" -optional = false -python-versions = "*" - -[package.dependencies] -fire = "*" - [[package]] name = "requests" version = "2.28.1" @@ -1333,17 +1310,6 @@ python-versions = ">=3.8" [package.dependencies] pbr = ">=2.0.0,<2.1.0 || >2.1.0" -[[package]] -name = "termcolor" -version = "2.0.1" -description = "ANSI color formatting for output in terminal" -category = "main" -optional = false -python-versions = ">=3.7" - -[package.extras] -tests = ["pytest-cov", "pytest"] - [[package]] name = "text-unidecode" version = "1.3" @@ -1497,7 +1463,7 @@ redshift = ["psycopg2-binary", "psycopg2cffi"] [metadata] lock-version = "1.1" python-versions = "^3.8,<3.11" -content-hash = "e231693ee02a89e14e8168e0b4d74c284e3f9b0102f211d92efcb62723e19abb" +content-hash = "24bd34ae0bdd70f265ba4fb25f28b084fa55f6f8c9563f482fd79c1d0d695563" [metadata.files] agate = [ @@ -1641,7 +1607,6 @@ decopatch = [ {file = "decopatch-1.4.10.tar.gz", hash = "sha256:957f49c93f4150182c23f8fb51d13bb3213e0f17a79e09c8cca7057598b55720"}, ] domdf-python-tools = [] -fire = [] flake8 = [] flake8-bugbear = [] flake8-builtins = [ @@ -2170,7 +2135,6 @@ pyyaml = [ {file = "PyYAML-5.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:c20cfa2d49991c8b4147af39859b167664f2ad4561704ee74c1de03318e898db"}, {file = "PyYAML-5.4.1.tar.gz", hash = "sha256:607774cbba28732bfa802b54baa7484215f530991055bb562efbed5b2f20a45e"}, ] -randomname = [] requests = [ {file = "requests-2.28.1-py3-none-any.whl", hash = "sha256:8fefa2a1a1365bf5520aac41836fbee479da67864514bdb821f31ce07ce65349"}, {file = "requests-2.28.1.tar.gz", hash = "sha256:7c5599b102feddaa661c826c56ab4fee28bfd17f5abca1ebbe3e7f19d7c97983"}, @@ -2261,7 +2225,6 @@ sqlparse = [ {file = "sqlparse-0.4.2.tar.gz", hash = "sha256:0c00730c74263a94e5a9919ade150dfc3b19c574389985446148402998287dae"}, ] stevedore = [] -termcolor = [] text-unidecode = [ {file = "text-unidecode-1.3.tar.gz", hash = "sha256:bad6603bb14d279193107714b288be206cac565dfa49aa5b105294dd5c4aab93"}, {file = "text_unidecode-1.3-py2.py3-none-any.whl", hash = "sha256:1311f10e8b895935241623731c2ba64f4c455287888b18189350b67134a822e8"}, diff --git a/pyproject.toml b/pyproject.toml index 9601caebe4..5bdb5c6726 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,6 @@ GitPython = {version = "^3.1.26", optional = true} dbt-core = {version = "1.0.6", optional = true} dbt-redshift = {version = "1.0.1", optional = true} dbt-bigquery = {version = "1.0.0", optional = true} -randomname = "^0.1.5" tzdata = "^2022.1" tomlkit = "^0.11.3" asyncstdlib = "^3.10.5" diff --git a/tests/common/runners/test_runnable.py b/tests/common/runners/test_runnable.py index eef2771e23..4725fd6fcf 100644 --- a/tests/common/runners/test_runnable.py +++ b/tests/common/runners/test_runnable.py @@ -1,11 +1,8 @@ import gc -from typing import Type import pytest -from multiprocessing import get_start_method from multiprocessing.pool import Pool from multiprocessing.dummy import Pool as ThreadPool -from dlt.common.runners.runnable import configuredworker from dlt.common.utils import uniq_id from dlt.normalize.configuration import NormalizeConfiguration @@ 
-73,35 +70,20 @@ def test_weak_pool_ref() -> None: def test_configuredworker() -> None: # call worker method with CONFIG values that should be restored into CONFIG type - config = NormalizeConfiguration.as_dict() + config = NormalizeConfiguration() config["import_schema_path"] = "test_schema_path" _worker_1(config, "PX1", par2="PX2") - # may also be called directly - NormT = type("TEST_" + uniq_id(), (NormalizeConfiguration, ), {}) - NormT.IMPORT_SCHEMA_PATH = "test_schema_path" - _worker_1(NormT, "PX1", par2="PX2") - # must also work across process boundary with Pool(1) as p: p.starmap(_worker_1, [(config, "PX1", "PX2")]) - # wrong signature error - with pytest.raises(ValueError): - _wrong_worker_sig(config) - - -@configuredworker -def _wrong_worker_sig(CONFIG: NormalizeConfiguration) -> None: - pass -@configuredworker -def _worker_1(CONFIG: Type[NormalizeConfiguration], par1: str, par2: str = "DEFAULT") -> None: - assert issubclass(CONFIG, NormalizeConfiguration) - # it is a subclass but not the same type - assert not CONFIG is NormalizeConfiguration +def _worker_1(CONFIG: NormalizeConfiguration, par1: str, par2: str = "DEFAULT") -> None: + # a correct type was passed + assert type(CONFIG) is NormalizeConfiguration # check if config values are restored - assert CONFIG.IMPORT_SCHEMA_PATH == "test_schema_path" + assert CONFIG.import_schema_path == "test_schema_path" # check if other parameters are correctly assert par1 == "PX1" assert par2 == "PX2" \ No newline at end of file diff --git a/tests/common/runners/test_runners.py b/tests/common/runners/test_runners.py index fc1e5df55b..666839b9dc 100644 --- a/tests/common/runners/test_runners.py +++ b/tests/common/runners/test_runners.py @@ -5,8 +5,7 @@ from dlt.cli import TRunnerArgs from dlt.common import signals -from dlt.common.typing import StrAny -from dlt.common.configuration import PoolRunnerConfiguration, make_configuration +from dlt.common.configuration import PoolRunnerConfiguration, make_configuration, configspec from dlt.common.configuration.pool_runner_configuration import TPoolType from dlt.common.exceptions import DltException, SignalReceivedException, TimeRangeExhaustedException, UnsupportedProcessStartMethodException from dlt.common.runners import pool_runner as runner @@ -14,34 +13,40 @@ from tests.common.runners.utils import _TestRunnable from tests.utils import init_logger + +@configspec class ModPoolRunnerConfiguration(PoolRunnerConfiguration): - IS_SINGLE_RUN: bool = True - WAIT_RUNS: int = 1 - PIPELINE_NAME: str = "testrunners" - POOL_TYPE: TPoolType = "none" - RUN_SLEEP: float = 0.1 - RUN_SLEEP_IDLE: float = 0.1 - RUN_SLEEP_WHEN_FAILED: float = 0.1 + is_single_run: bool = True + wait_runs: int = 1 + pipeline_name: str = "testrunners" + pool_type: TPoolType = "none" + run_sleep: float = 0.1 + run_sleep_idle: float = 0.1 + run_sleep_when_failed: float = 0.1 +@configspec class StopExceptionRunnerConfiguration(ModPoolRunnerConfiguration): - EXIT_ON_EXCEPTION: bool = True + exit_on_exception: bool = True +@configspec class LimitedPoolRunnerConfiguration(ModPoolRunnerConfiguration): - STOP_AFTER_RUNS: int = 5 + stop_after_runs: int = 5 +@configspec class ProcessPoolConfiguration(ModPoolRunnerConfiguration): - POOL_TYPE: TPoolType = "process" + pool_type: TPoolType = "process" +@configspec class ThreadPoolConfiguration(ModPoolRunnerConfiguration): - POOL_TYPE: TPoolType = "thread" + pool_type: TPoolType = "thread" -def configure(C: Type[PoolRunnerConfiguration], args: TRunnerArgs) -> Type[PoolRunnerConfiguration]: - return 
make_configuration(C, C, initial_values=args._asdict()) +def configure(C: Type[PoolRunnerConfiguration], args: TRunnerArgs) -> PoolRunnerConfiguration: + return make_configuration(C(), initial_value=args._asdict()) @pytest.fixture(scope="module", autouse=True) diff --git a/tests/common/schema/test_schema.py b/tests/common/schema/test_schema.py index 4ffb0b2399..5d58b93434 100644 --- a/tests/common/schema/test_schema.py +++ b/tests/common/schema/test_schema.py @@ -21,9 +21,8 @@ @pytest.fixture def schema_storage() -> SchemaStorage: C = make_configuration( - SchemaVolumeConfiguration, - SchemaVolumeConfiguration, - initial_values={ + SchemaVolumeConfiguration(), + initial_value={ "import_schema_path": "tests/common/cases/schemas/rasa", "external_schema_format": "json" }) diff --git a/tests/common/storages/test_loader_storage.py b/tests/common/storages/test_loader_storage.py index 68c75e64e8..d6143fa7e8 100644 --- a/tests/common/storages/test_loader_storage.py +++ b/tests/common/storages/test_loader_storage.py @@ -16,7 +16,7 @@ @pytest.fixture def storage() -> LoadStorage: - C = make_configuration(LoadVolumeConfiguration, LoadVolumeConfiguration) + C = make_configuration(LoadVolumeConfiguration()) s = LoadStorage(True, C, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) return s diff --git a/tests/common/storages/test_schema_storage.py b/tests/common/storages/test_schema_storage.py index fe2d7d5d05..547ac67a11 100644 --- a/tests/common/storages/test_schema_storage.py +++ b/tests/common/storages/test_schema_storage.py @@ -26,23 +26,23 @@ def storage() -> SchemaStorage: @pytest.fixture def synced_storage() -> SchemaStorage: # will be created in /schemas - return init_storage({"IMPORT_SCHEMA_PATH": TEST_STORAGE_ROOT + "/import", "EXPORT_SCHEMA_PATH": TEST_STORAGE_ROOT + "/import"}) + return init_storage({"import_schema_path": TEST_STORAGE_ROOT + "/import", "export_schema_path": TEST_STORAGE_ROOT + "/import"}) @pytest.fixture def ie_storage() -> SchemaStorage: # will be created in /schemas - return init_storage({"IMPORT_SCHEMA_PATH": TEST_STORAGE_ROOT + "/import", "EXPORT_SCHEMA_PATH": TEST_STORAGE_ROOT + "/export"}) + return init_storage({"import_schema_path": TEST_STORAGE_ROOT + "/import", "export_schema_path": TEST_STORAGE_ROOT + "/export"}) def init_storage(initial: DictStrAny = None) -> SchemaStorage: - C = make_configuration(SchemaVolumeConfiguration, SchemaVolumeConfiguration, initial_values=initial) + C = make_configuration(SchemaVolumeConfiguration(), initial_value=initial) # use live schema storage for test which must be backward compatible with schema storage s = LiveSchemaStorage(C, makedirs=True) - if C.EXPORT_SCHEMA_PATH: - os.makedirs(C.EXPORT_SCHEMA_PATH, exist_ok=True) - if C.IMPORT_SCHEMA_PATH: - os.makedirs(C.IMPORT_SCHEMA_PATH, exist_ok=True) + if C.export_schema_path: + os.makedirs(C.export_schema_path, exist_ok=True) + if C.import_schema_path: + os.makedirs(C.import_schema_path, exist_ok=True) return s @@ -86,7 +86,7 @@ def test_skip_import_if_not_modified(synced_storage: SchemaStorage, storage: Sch # the import schema gets modified storage_schema.tables["_dlt_loads"]["write_disposition"] = "append" storage_schema.tables.pop("event_user") - synced_storage._export_schema(storage_schema, synced_storage.C.EXPORT_SCHEMA_PATH) + synced_storage._export_schema(storage_schema, synced_storage.C.export_schema_path) # now load will import again reloaded_schema = synced_storage.load_schema("ethereum") # we have overwritten storage schema @@ -113,7 +113,7 @@ def 
test_store_schema_tampered(synced_storage: SchemaStorage, storage: SchemaSto def test_schema_export(ie_storage: SchemaStorage) -> None: schema = Schema("ethereum") - fs = FileStorage(ie_storage.C.EXPORT_SCHEMA_PATH) + fs = FileStorage(ie_storage.C.export_schema_path) exported_name = ie_storage._file_name_in_store("ethereum", "yaml") # no exported schema assert not fs.has_file(exported_name) @@ -192,7 +192,7 @@ def test_save_store_schema_over_import(ie_storage: SchemaStorage) -> None: assert schema.version_hash == schema_hash assert schema._imported_version_hash == "njJAySgJRs2TqGWgQXhP+3pCh1A1hXcqe77BpM7JtOU=" # we have simple schema in export folder - fs = FileStorage(ie_storage.C.EXPORT_SCHEMA_PATH) + fs = FileStorage(ie_storage.C.export_schema_path) exported_name = ie_storage._file_name_in_store("ethereum", "yaml") exported_schema = yaml.safe_load(fs.load(exported_name)) assert schema.version_hash == exported_schema["version_hash"] @@ -206,7 +206,7 @@ def test_save_store_schema_over_import_sync(synced_storage: SchemaStorage) -> No synced_storage.save_schema(schema) assert schema._imported_version_hash == "njJAySgJRs2TqGWgQXhP+3pCh1A1hXcqe77BpM7JtOU=" # import schema is overwritten - fs = FileStorage(synced_storage.C.IMPORT_SCHEMA_PATH) + fs = FileStorage(synced_storage.C.import_schema_path) exported_name = synced_storage._file_name_in_store("ethereum", "yaml") exported_schema = yaml.safe_load(fs.load(exported_name)) assert schema.version_hash == exported_schema["version_hash"] == schema_hash diff --git a/tests/common/test_configuration.py b/tests/common/test_configuration.py index ca9da8f928..838e8fb846 100644 --- a/tests/common/test_configuration.py +++ b/tests/common/test_configuration.py @@ -1,14 +1,12 @@ import pytest from os import environ -from typing import Any, Dict, List, NewType, Optional, Tuple +from typing import Any, Dict, List, Mapping, MutableMapping, NewType, Optional, Tuple, Type from dlt.common.typing import TSecretValue from dlt.common.configuration import ( RunConfiguration, ConfigEntryMissingException, ConfigFileNotFoundException, - ConfigEnvValueCannotBeCoercedException, BaseConfiguration, utils) -from dlt.common.configuration.utils import (_coerce_single_value, IS_DEVELOPMENT_CONFIG_KEY, - _get_config_attrs_with_hints, - is_direct_descendant, make_configuration) + ConfigEnvValueCannotBeCoercedException, BaseConfiguration, utils, configspec) +from dlt.common.configuration.utils import make_configuration from dlt.common.configuration.providers import environ as environ_provider from tests.utils import preserve_environ @@ -16,23 +14,21 @@ # used to test version __version__ = "1.0.5" -IS_DEVELOPMENT_CONFIG = 'DEBUG' -NONE_CONFIG_VAR = 'NoneConfigVar' COERCIONS = { - 'STR_VAL': 'test string', - 'INT_VAL': 12345, - 'BOOL_VAL': True, - 'LIST_VAL': [1, "2", [3]], - 'DICT_VAL': { + 'str_val': 'test string', + 'int_val': 12345, + 'bool_val': True, + 'list_val': [1, "2", [3]], + 'dict_val': { 'a': 1, "b": "2" }, - 'TUPLE_VAL': (1, 2, '7'), - 'SET_VAL': {1, 2, 3}, - 'BYTES_VAL': b'Hello World!', - 'FLOAT_VAL': 1.18927, - 'ANY_VAL': "function() {}", - 'NONE_VAL': "none", + 'tuple_val': (1, 2, '7'), + 'set_val': {1, 2, 3}, + 'bytes_val': b'Hello World!', + 'float_val': 1.18927, + 'any_val': "function() {}", + 'none_val': "none", 'COMPLEX_VAL': { "_": (1440, ["*"], []), "change-email": (560, ["*"], []) @@ -41,98 +37,115 @@ INVALID_COERCIONS = { # 'STR_VAL': 'test string', # string always OK - 'INT_VAL': "a12345", - 'BOOL_VAL': "Yes", # bool overridden by string - that is the 
most common problem - 'LIST_VAL': {1, "2", 3.0}, - 'DICT_VAL': "{'a': 1, 'b', '2'}", - 'TUPLE_VAL': [1, 2, '7'], - 'SET_VAL': [1, 2, 3], - 'BYTES_VAL': 'Hello World!', - 'FLOAT_VAL': "invalid" + 'int_val': "a12345", + 'bool_val': "Yes", # bool overridden by string - that is the most common problem + 'list_val': {1, "2", 3.0}, + 'dict_val': "{'a': 1, 'b', '2'}", + 'tuple_val': [1, 2, '7'], + 'set_val': [1, 2, 3], + 'bytes_val': 'Hello World!', + 'float_val': "invalid" } EXCEPTED_COERCIONS = { # allows to use int for float - 'FLOAT_VAL': 10, + 'float_val': 10, # allows to use float for str - 'STR_VAL': 10.0 + 'str_val': 10.0 } COERCED_EXCEPTIONS = { # allows to use int for float - 'FLOAT_VAL': 10.0, + 'float_val': 10.0, # allows to use float for str - 'STR_VAL': "10.0" + 'str_val': "10.0" } +@configspec class SimpleConfiguration(RunConfiguration): - PIPELINE_NAME: str = "Some Name" + pipeline_name: str = "Some Name" + test_bool: bool = False +@configspec class WrongConfiguration(RunConfiguration): - PIPELINE_NAME: str = "Some Name" - NoneConfigVar = None - LOG_COLOR: bool = True + pipeline_name: str = "Some Name" + NoneConfigVar: str = None + log_color: bool = True +@configspec class SecretConfiguration(RunConfiguration): - PIPELINE_NAME: str = "secret" - SECRET_VALUE: TSecretValue = None + pipeline_name: str = "secret" + secret_value: TSecretValue = None +@configspec class SecretKubeConfiguration(RunConfiguration): - PIPELINE_NAME: str = "secret kube" - SECRET_KUBE: TSecretValue = None - - -class TestCoercionConfiguration(RunConfiguration): - PIPELINE_NAME: str = "Some Name" - STR_VAL: str = None - INT_VAL: int = None - BOOL_VAL: bool = None - LIST_VAL: list = None # type: ignore - DICT_VAL: dict = None # type: ignore - TUPLE_VAL: tuple = None # type: ignore - BYTES_VAL: bytes = None - SET_VAL: set = None # type: ignore - FLOAT_VAL: float = None - ANY_VAL: Any = None - NONE_VAL = None + pipeline_name: str = "secret kube" + secret_kube: TSecretValue = None + + +@configspec +class CoercionTestConfiguration(RunConfiguration): + pipeline_name: str = "Some Name" + str_val: str = None + int_val: int = None + bool_val: bool = None + list_val: list = None # type: ignore + dict_val: dict = None # type: ignore + tuple_val: tuple = None # type: ignore + bytes_val: bytes = None + set_val: set = None # type: ignore + float_val: float = None + any_val: Any = None + none_val: str = None COMPLEX_VAL: Dict[str, Tuple[int, List[str], List[str]]] = None +@configspec class VeryWrongConfiguration(WrongConfiguration): - PIPELINE_NAME: str = "Some Name" - STR_VAL: str = "" - INT_VAL: int = None - LOG_COLOR: str = "1" # type: ignore + pipeline_name: str = "Some Name" + str_val: str = "" + int_val: int = None + log_color: str = "1" # type: ignore +@configspec class ConfigurationWithOptionalTypes(RunConfiguration): - PIPELINE_NAME: str = "Some Name" + pipeline_name: str = "Some Name" - STR_VAL: Optional[str] = None - INT_VAL: Optional[int] = None - BOOL_VAL: bool = True + str_val: Optional[str] = None + int_val: Optional[int] = None + bool_val: bool = True +@configspec class ProdConfigurationWithOptionalTypes(ConfigurationWithOptionalTypes): - PROD_VAL: str = "prod" + prod_val: str = "prod" +@configspec class MockProdConfiguration(RunConfiguration): - PIPELINE_NAME: str = "comp" + pipeline_name: str = "comp" +@configspec class MockProdConfigurationVar(RunConfiguration): - PIPELINE_NAME: str = "comp" + pipeline_name: str = "comp" +@configspec class NamespacedConfiguration(BaseConfiguration): __namespace__ = 
"DLT_TEST" - PASSWORD: str = None + password: str = None + + +@configspec(init=True) +class FieldWithNoDefaultConfiguration(RunConfiguration): + no_default: str LongInteger = NewType("LongInteger", int) @@ -147,157 +160,208 @@ def environment() -> Any: def test_run_configuration_gen_name(environment: Any) -> None: - C = make_configuration(RunConfiguration, RunConfiguration) - assert C.PIPELINE_NAME.startswith("dlt_") + C = make_configuration(RunConfiguration()) + assert C.pipeline_name.startswith("dlt_") -def test_configuration_to_dict(environment: Any) -> None: +def test_configuration_is_mutable_mapping(environment: Any) -> None: + # configurations provide full MutableMapping support + # here order of items in dict matters expected_dict = { - 'CONFIG_FILES_STORAGE_PATH': '_storage/config/%s', - 'IS_DEVELOPMENT_CONFIG': True, - 'LOG_FORMAT': '{asctime}|[{levelname:<21}]|{process}|{name}|{filename}|{funcName}:{lineno}|{message}', - 'LOG_LEVEL': 'DEBUG', - 'PIPELINE_NAME': 'secret', - 'PROMETHEUS_PORT': None, - 'REQUEST_TIMEOUT': (15, 300), - 'SECRET_VALUE': None, - 'SENTRY_DSN': None + 'pipeline_name': 'secret', + 'sentry_dsn': None, + 'prometheus_port': None, + 'log_format': '{asctime}|[{levelname:<21}]|{process}|{name}|{filename}|{funcName}:{lineno}|{message}', + 'log_level': 'DEBUG', + 'request_timeout': (15, 300), + 'config_files_storage_path': '_storage/config/%s', + 'secret_value': None } - assert SecretConfiguration.as_dict() == {k.lower():v for k,v in expected_dict.items()} - assert SecretConfiguration.as_dict(lowercase=False) == expected_dict + assert dict(SecretConfiguration()) == expected_dict environment["SECRET_VALUE"] = "secret" - C = make_configuration(SecretConfiguration, SecretConfiguration) - d = C.as_dict(lowercase=False) - expected_dict["_VERSION"] = d["_VERSION"] - expected_dict["SECRET_VALUE"] = "secret" - assert d == expected_dict + C = make_configuration(SecretConfiguration()) + expected_dict["secret_value"] = "secret" + assert dict(C) == expected_dict + + # check mutable mapping type + assert isinstance(C, MutableMapping) + assert isinstance(C, Mapping) + assert not isinstance(C, Dict) + + # check view ops + assert C.keys() == expected_dict.keys() + assert len(C) == len(expected_dict) + assert C.items() == expected_dict.items() + assert list(C.values()) == list(expected_dict.values()) + for key in C: + assert C[key] == expected_dict[key] + # version is present as attr but not present in dict + assert hasattr(C, "_version") + assert hasattr(C, "__is_partial__") + assert hasattr(C, "__namespace__") + + with pytest.raises(KeyError): + C["_version"] + + # set ops + # update supported and non existing attributes are ignored + C.update({"pipeline_name": "old pipe", "__version": "1.1.1"}) + assert C.pipeline_name == "old pipe" == C["pipeline_name"] + assert C._version != "1.1.1" + + # delete is not supported + with pytest.raises(NotImplementedError): + del C["pipeline_name"] + + with pytest.raises(NotImplementedError): + C.pop("pipeline_name", None) + + # setting supported + C["pipeline_name"] = "new pipe" + assert C.pipeline_name == "new pipe" == C["pipeline_name"] + + with pytest.raises(KeyError): + C["_version"] = "1.1.1" + + +def test_fields_with_no_default_to_null() -> None: + # fields with no default are promoted to class attrs with none + assert FieldWithNoDefaultConfiguration.no_default is None + assert FieldWithNoDefaultConfiguration().no_default is None + + +def test_init_method_gen() -> None: + C = FieldWithNoDefaultConfiguration(no_default="no_default", 
sentry_dsn="SENTRY") + assert C.no_default == "no_default" + assert C.sentry_dsn == "SENTRY" + +def test_multi_derivation_defaults() -> None: -def test_configuration_rise_exception_when_config_is_not_complete() -> None: + @configspec + class MultiConfiguration(MockProdConfiguration, ConfigurationWithOptionalTypes, NamespacedConfiguration): + pass + + # apparently dataclasses set default in reverse mro so MockProdConfiguration overwrites + C = MultiConfiguration() + assert C.pipeline_name == MultiConfiguration.pipeline_name == "comp" + # but keys are ordered in MRO so password from NamespacedConfiguration goes first + keys = list(C.keys()) + assert keys[0] == "password" + assert keys[-1] == "bool_val" + assert C.__namespace__ == "DLT_TEST" + + +def test_raises_on_unresolved_fields() -> None: with pytest.raises(ConfigEntryMissingException) as config_entry_missing_exception: - keys = _get_config_attrs_with_hints(WrongConfiguration) - utils._is_config_bounded(WrongConfiguration, keys) + C = WrongConfiguration() + keys = utils._get_resolvable_fields(C) + utils._is_config_bounded(C, keys) - assert 'NoneConfigVar' in config_entry_missing_exception.value.missing_set + assert 'NONECONFIGVAR' in config_entry_missing_exception.value.missing_set + + # via make configuration + with pytest.raises(ConfigEntryMissingException) as config_entry_missing_exception: + make_configuration(WrongConfiguration()) + assert 'NONECONFIGVAR' in config_entry_missing_exception.value.missing_set def test_optional_types_are_not_required() -> None: # this should not raise an exception - keys = _get_config_attrs_with_hints(ConfigurationWithOptionalTypes) - utils._is_config_bounded(ConfigurationWithOptionalTypes, keys) + keys = utils._get_resolvable_fields(ConfigurationWithOptionalTypes()) + utils._is_config_bounded(ConfigurationWithOptionalTypes(), keys) # make optional config - make_configuration(ConfigurationWithOptionalTypes, ConfigurationWithOptionalTypes) + make_configuration(ConfigurationWithOptionalTypes()) # make config with optional values - make_configuration( - ProdConfigurationWithOptionalTypes, - ProdConfigurationWithOptionalTypes, - initial_values={"INT_VAL": None} - ) + make_configuration(ProdConfigurationWithOptionalTypes(), initial_value={"INT_VAL": None}) def test_configuration_apply_adds_environment_variable_to_config(environment: Any) -> None: - environment[NONE_CONFIG_VAR] = "Some" + environment["NONECONFIGVAR"] = "Some" - keys = _get_config_attrs_with_hints(WrongConfiguration) - utils._apply_environ_to_config(WrongConfiguration, keys) - utils._is_config_bounded(WrongConfiguration, keys) + C = WrongConfiguration() + keys = utils._get_resolvable_fields(C) + utils._resolve_config_fields(C, keys, accept_partial=False) + utils._is_config_bounded(C, keys) - # NoneConfigVar has no hint so value not coerced from string - assert WrongConfiguration.NoneConfigVar == environment[NONE_CONFIG_VAR] + assert C.NoneConfigVar == environment["NONECONFIGVAR"] -def test_configuration_resolve(environment: Any) -> None: - environment[IS_DEVELOPMENT_CONFIG] = 'True' +def test_configuration_resolve_env_var(environment: Any) -> None: + environment["TEST_BOOL"] = 'True' - keys = _get_config_attrs_with_hints(SimpleConfiguration) - utils._apply_environ_to_config(SimpleConfiguration, keys) - utils._is_config_bounded(SimpleConfiguration, keys) + C = SimpleConfiguration() + keys = utils._get_resolvable_fields(C) + utils._resolve_config_fields(C, keys, accept_partial=False) + utils._is_config_bounded(C, keys) # value will be coerced to 
bool - assert RunConfiguration.IS_DEVELOPMENT_CONFIG is True + assert C.test_bool is True def test_find_all_keys() -> None: - keys = _get_config_attrs_with_hints(VeryWrongConfiguration) - # assert hints and types: NoneConfigVar has no type hint and LOG_COLOR had it hint overwritten in derived class - assert set({'STR_VAL': str, 'INT_VAL': int, 'NoneConfigVar': None, 'LOG_COLOR': str}.items()).issubset(keys.items()) + keys = utils._get_resolvable_fields(VeryWrongConfiguration()) + # assert hints and types: LOG_COLOR had it hint overwritten in derived class + assert set({'str_val': str, 'int_val': int, 'NoneConfigVar': str, 'log_color': str}.items()).issubset(keys.items()) def test_coercions(environment: Any) -> None: for key, value in COERCIONS.items(): - environment[key] = str(value) + environment[key.upper()] = str(value) - keys = _get_config_attrs_with_hints(TestCoercionConfiguration) - utils._apply_environ_to_config(TestCoercionConfiguration, keys) - utils._is_config_bounded(TestCoercionConfiguration, keys) + C = CoercionTestConfiguration() + keys = utils._get_resolvable_fields(C) + utils._resolve_config_fields(C, keys, accept_partial=False) + utils._is_config_bounded(C, keys) for key in COERCIONS: - assert getattr(TestCoercionConfiguration, key) == COERCIONS[key] + assert getattr(C, key) == COERCIONS[key] def test_invalid_coercions(environment: Any) -> None: - config_keys = _get_config_attrs_with_hints(TestCoercionConfiguration) + C = CoercionTestConfiguration() + config_keys = utils._get_resolvable_fields(C) for key, value in INVALID_COERCIONS.items(): try: - environment[key] = str(value) - utils._apply_environ_to_config(TestCoercionConfiguration, config_keys) + environment[key.upper()] = str(value) + utils._resolve_config_fields(C, config_keys, accept_partial=False) except ConfigEnvValueCannotBeCoercedException as coerc_exc: - # must fail excatly on expected value + # must fail exactly on expected value if coerc_exc.attr_name != key: raise # overwrite with valid value and go to next env - environment[key] = str(COERCIONS[key]) + environment[key.upper()] = str(COERCIONS[key]) continue raise AssertionError("%s was coerced with %s which is invalid type" % (key, value)) def test_excepted_coercions(environment: Any) -> None: - config_keys = _get_config_attrs_with_hints(TestCoercionConfiguration) + C = CoercionTestConfiguration() + config_keys = utils._get_resolvable_fields(C) for k, v in EXCEPTED_COERCIONS.items(): - environment[k] = str(v) - utils._apply_environ_to_config(TestCoercionConfiguration, config_keys) + environment[k.upper()] = str(v) + utils._resolve_config_fields(C, config_keys, accept_partial=False) for key in EXCEPTED_COERCIONS: - assert getattr(TestCoercionConfiguration, key) == COERCED_EXCEPTIONS[key] - - -def test_development_config_detection(environment: Any) -> None: - # default is true - assert utils._is_development_config() - environment[IS_DEVELOPMENT_CONFIG_KEY] = "False" - # explicit values - assert not utils._is_development_config() - environment[IS_DEVELOPMENT_CONFIG_KEY] = "True" - assert utils._is_development_config() - # raise exception on env value that cannot be coerced to bool - with pytest.raises(ConfigEnvValueCannotBeCoercedException): - environment[IS_DEVELOPMENT_CONFIG_KEY] = "NONBOOL" - utils._is_development_config() + assert getattr(C, key) == COERCED_EXCEPTIONS[key] def test_make_configuration(environment: Any) -> None: # fill up configuration - environment['INT_VAL'] = "1" - C = utils.make_configuration(WrongConfiguration, VeryWrongConfiguration) + 
environment["NONECONFIGVAR"] = "1" + C = utils.make_configuration(WrongConfiguration()) assert not C.__is_partial__ - # default is true - assert is_direct_descendant(C, WrongConfiguration) - environment[IS_DEVELOPMENT_CONFIG_KEY] = "False" - assert is_direct_descendant(utils.make_configuration(WrongConfiguration, VeryWrongConfiguration), VeryWrongConfiguration) - environment[IS_DEVELOPMENT_CONFIG_KEY] = "True" - assert is_direct_descendant(utils.make_configuration(WrongConfiguration, VeryWrongConfiguration), WrongConfiguration) + assert C.NoneConfigVar == "1" def test_auto_derivation(environment: Any) -> None: - # make_configuration auto derives a type and never modifies the original type + # make_configuration works on instances of dataclasses and types are not modified environment['SECRET_VALUE'] = "1" - C = utils.make_configuration(SecretConfiguration, SecretConfiguration) + C = utils.make_configuration(SecretConfiguration()) # auto derived type holds the value - assert C.SECRET_VALUE == "1" + assert C.secret_value == "1" # base type is untouched - assert SecretConfiguration.SECRET_VALUE is None - # type name is derived - assert C.__name__.startswith("SecretConfiguration_") + assert SecretConfiguration.secret_value is None def test_initial_values(environment: Any) -> None: @@ -305,22 +369,24 @@ def test_initial_values(environment: Any) -> None: environment["PIPELINE_NAME"] = "env name" environment["CREATED_VAL"] = "12837" # set initial values and allow partial config - C = make_configuration(TestCoercionConfiguration, TestCoercionConfiguration, - {"PIPELINE_NAME": "initial name", "NONE_VAL": type(environment), "CREATED_VAL": 878232, "BYTES_VAL": b"str"}, + C = make_configuration(CoercionTestConfiguration(), + {"pipeline_name": "initial name", "none_val": type(environment), "created_val": 878232, "bytes_val": b"str"}, accept_partial=True ) # from env - assert C.PIPELINE_NAME == "env name" + assert C.pipeline_name == "env name" # from initial - assert C.BYTES_VAL == b"str" - assert C.NONE_VAL == type(environment) + assert C.bytes_val == b"str" + assert C.none_val == type(environment) # new prop overridden from env assert environment["CREATED_VAL"] == "12837" def test_accept_partial(environment: Any) -> None: + # modify original type WrongConfiguration.NoneConfigVar = None - C = make_configuration(WrongConfiguration, WrongConfiguration, accept_partial=True) + # that None value will be present in the instance + C = make_configuration(WrongConfiguration(), accept_partial=True) assert C.NoneConfigVar is None # partial resolution assert C.__is_partial__ @@ -330,24 +396,22 @@ def test_finds_version(environment: Any) -> None: global __version__ v = __version__ - C = utils.make_configuration(SimpleConfiguration, SimpleConfiguration) - assert C._VERSION == v + C = utils.make_configuration(SimpleConfiguration()) + assert C._version == v try: del globals()["__version__"] - # C is a type, not instance and holds the _VERSION from previous extract - delattr(C, "_VERSION") - C = utils.make_configuration(SimpleConfiguration, SimpleConfiguration) - assert not hasattr(C, "_VERSION") + C = utils.make_configuration(SimpleConfiguration()) + assert not hasattr(C, "_version") finally: __version__ = v def test_secret(environment: Any) -> None: with pytest.raises(ConfigEntryMissingException): - utils.make_configuration(SecretConfiguration, SecretConfiguration) + utils.make_configuration(SecretConfiguration()) environment['SECRET_VALUE'] = "1" - C = utils.make_configuration(SecretConfiguration, 
SecretConfiguration) - assert C.SECRET_VALUE == "1" + C = utils.make_configuration(SecretConfiguration()) + assert C.secret_value == "1" # mock the path to point to secret storage # from dlt.common.configuration import config_utils path = environ_provider.SECRET_STORAGE_PATH @@ -355,19 +419,19 @@ def test_secret(environment: Any) -> None: try: # must read a secret file environ_provider.SECRET_STORAGE_PATH = "./tests/common/cases/%s" - C = utils.make_configuration(SecretConfiguration, SecretConfiguration) - assert C.SECRET_VALUE == "BANANA" + C = utils.make_configuration(SecretConfiguration()) + assert C.secret_value == "BANANA" # set some weird path, no secret file at all del environment['SECRET_VALUE'] environ_provider.SECRET_STORAGE_PATH = "!C:\\PATH%s" with pytest.raises(ConfigEntryMissingException): - utils.make_configuration(SecretConfiguration, SecretConfiguration) + utils.make_configuration(SecretConfiguration()) # set env which is a fallback for secret not as file environment['SECRET_VALUE'] = "1" - C = utils.make_configuration(SecretConfiguration, SecretConfiguration) - assert C.SECRET_VALUE == "1" + C = utils.make_configuration(SecretConfiguration()) + assert C.secret_value == "1" finally: environ_provider.SECRET_STORAGE_PATH = path @@ -376,65 +440,44 @@ def test_secret_kube_fallback(environment: Any) -> None: path = environ_provider.SECRET_STORAGE_PATH try: environ_provider.SECRET_STORAGE_PATH = "./tests/common/cases/%s" - C = utils.make_configuration(SecretKubeConfiguration, SecretKubeConfiguration) + C = utils.make_configuration(SecretKubeConfiguration()) # all unix editors will add x10 at the end of file, it will be preserved - assert C.SECRET_KUBE == "kube\n" + assert C.secret_kube == "kube\n" # we propagate secrets back to environ and strip the whitespace assert environment['SECRET_KUBE'] == "kube" finally: environ_provider.SECRET_STORAGE_PATH = path -def test_configuration_must_be_subclass_of_prod(environment: Any) -> None: - # fill up configuration - environment['INT_VAL'] = "1" - # prod must inherit from config - with pytest.raises(AssertionError): - # VeryWrongConfiguration does not descend inherit from ConfigurationWithOptionalTypes so it cannot be production config of it - utils.make_configuration(ConfigurationWithOptionalTypes, VeryWrongConfiguration) - - def test_coerce_values() -> None: with pytest.raises(ConfigEnvValueCannotBeCoercedException): - _coerce_single_value("key", "some string", int) - assert _coerce_single_value("key", "some string", str) == "some string" + coerce_single_value("key", "some string", int) + assert coerce_single_value("key", "some string", str) == "some string" # Optional[str] has type object, mypy will never work properly... 
- assert _coerce_single_value("key", "some string", Optional[str]) == "some string" # type: ignore + assert coerce_single_value("key", "some string", Optional[str]) == "some string" # type: ignore - assert _coerce_single_value("key", "234", int) == 234 - assert _coerce_single_value("key", "234", Optional[int]) == 234 # type: ignore + assert coerce_single_value("key", "234", int) == 234 + assert coerce_single_value("key", "234", Optional[int]) == 234 # type: ignore # check coercions of NewTypes - assert _coerce_single_value("key", "test str X", FirstOrderStr) == "test str X" - assert _coerce_single_value("key", "test str X", Optional[FirstOrderStr]) == "test str X" # type: ignore - assert _coerce_single_value("key", "test str X", Optional[SecondOrderStr]) == "test str X" # type: ignore - assert _coerce_single_value("key", "test str X", SecondOrderStr) == "test str X" - assert _coerce_single_value("key", "234", LongInteger) == 234 - assert _coerce_single_value("key", "234", Optional[LongInteger]) == 234 # type: ignore + assert coerce_single_value("key", "test str X", FirstOrderStr) == "test str X" + assert coerce_single_value("key", "test str X", Optional[FirstOrderStr]) == "test str X" # type: ignore + assert coerce_single_value("key", "test str X", Optional[SecondOrderStr]) == "test str X" # type: ignore + assert coerce_single_value("key", "test str X", SecondOrderStr) == "test str X" + assert coerce_single_value("key", "234", LongInteger) == 234 + assert coerce_single_value("key", "234", Optional[LongInteger]) == 234 # type: ignore # this coercion should fail with pytest.raises(ConfigEnvValueCannotBeCoercedException): - _coerce_single_value("key", "some string", LongInteger) + coerce_single_value("key", "some string", LongInteger) with pytest.raises(ConfigEnvValueCannotBeCoercedException): - _coerce_single_value("key", "some string", Optional[LongInteger]) # type: ignore - - -def test_configuration_files_prod_path(environment: Any) -> None: - environment[IS_DEVELOPMENT_CONFIG_KEY] = "True" - C = utils.make_configuration(MockProdConfiguration, MockProdConfiguration) - assert C.CONFIG_FILES_STORAGE_PATH == "_storage/config/%s" - - environment[IS_DEVELOPMENT_CONFIG_KEY] = "False" - C = utils.make_configuration(MockProdConfiguration, MockProdConfiguration) - assert C.IS_DEVELOPMENT_CONFIG is False - assert C.CONFIG_FILES_STORAGE_PATH == "/run/config/%s" + coerce_single_value("key", "some string", Optional[LongInteger]) # type: ignore def test_configuration_files(environment: Any) -> None: # overwrite config file paths - environment[IS_DEVELOPMENT_CONFIG_KEY] = "False" environment["CONFIG_FILES_STORAGE_PATH"] = "./tests/common/cases/schemas/ev1/%s" - C = utils.make_configuration(MockProdConfigurationVar, MockProdConfigurationVar) - assert C.CONFIG_FILES_STORAGE_PATH == environment["CONFIG_FILES_STORAGE_PATH"] + C = utils.make_configuration(MockProdConfigurationVar()) + assert C.config_files_storage_path == environment["CONFIG_FILES_STORAGE_PATH"] assert C.has_configuration_file("hasn't") is False assert C.has_configuration_file("event_schema.json") is True assert C.get_configuration_file_path("event_schema.json") == "./tests/common/cases/schemas/ev1/event_schema.json" @@ -446,18 +489,22 @@ def test_configuration_files(environment: Any) -> None: def test_namespaced_configuration(environment: Any) -> None: with pytest.raises(ConfigEntryMissingException) as exc_val: - utils.make_configuration(NamespacedConfiguration, NamespacedConfiguration) + utils.make_configuration(NamespacedConfiguration()) 
assert exc_val.value.missing_set == ["DLT_TEST__PASSWORD"] assert exc_val.value.namespace == "DLT_TEST" # init vars work without namespace - C = utils.make_configuration(NamespacedConfiguration, NamespacedConfiguration, initial_values={"PASSWORD": "PASS"}) - assert C.PASSWORD == "PASS" + C = utils.make_configuration(NamespacedConfiguration(), initial_value={"password": "PASS"}) + assert C.password == "PASS" # env var must be prefixed environment["PASSWORD"] = "PASS" with pytest.raises(ConfigEntryMissingException) as exc_val: - utils.make_configuration(NamespacedConfiguration, NamespacedConfiguration) + utils.make_configuration(NamespacedConfiguration()) environment["DLT_TEST__PASSWORD"] = "PASS" - C = utils.make_configuration(NamespacedConfiguration, NamespacedConfiguration) - assert C.PASSWORD == "PASS" + C = utils.make_configuration(NamespacedConfiguration()) + assert C.password == "PASS" + +def coerce_single_value(key: str, value: str, hint: Type[Any]) -> Any: + hint = utils._extract_simple_type(hint) + return utils._coerce_single_value(key, value, hint) diff --git a/tests/common/test_logging.py b/tests/common/test_logging.py index ef3ef21040..8d0c95f3b4 100644 --- a/tests/common/test_logging.py +++ b/tests/common/test_logging.py @@ -6,29 +6,34 @@ from dlt import __version__ as auto_version from dlt.common import logger, sleep from dlt.common.typing import StrStr -from dlt.common.configuration import RunConfiguration +from dlt.common.configuration import RunConfiguration, configspec from tests.utils import preserve_environ +@configspec class PureBasicConfiguration(RunConfiguration): - PIPELINE_NAME: str = "logger" + pipeline_name: str = "logger" +@configspec class PureBasicConfigurationProc(PureBasicConfiguration): - _VERSION: str = "1.6.6" + _version: str = "1.6.6" +@configspec class JsonLoggerConfiguration(PureBasicConfigurationProc): - LOG_FORMAT: str = "JSON" + log_format: str = "JSON" +@configspec class SentryLoggerConfiguration(JsonLoggerConfiguration): - SENTRY_DSN: str = "http://user:pass@localhost/818782" + sentry_dsn: str = "http://user:pass@localhost/818782" +@configspec(init=True) class SentryLoggerCriticalConfiguration(SentryLoggerConfiguration): - LOG_LEVEL: str = "CRITICAL" + log_level: str = "CRITICAL" @pytest.fixture(scope="function") @@ -39,14 +44,14 @@ def environment() -> StrStr: def test_version_extract(environment: StrStr) -> None: - version = logger._extract_version_info(PureBasicConfiguration) + version = logger._extract_version_info(PureBasicConfiguration()) # if component ver not avail use system version assert version == {'version': auto_version, 'component_name': 'logger'} - version = logger._extract_version_info(PureBasicConfigurationProc) - assert version["component_version"] == PureBasicConfigurationProc._VERSION + version = logger._extract_version_info(PureBasicConfigurationProc()) + assert version["component_version"] == PureBasicConfigurationProc()._version # mock image info available in container _mock_image_env(environment) - version = logger._extract_version_info(PureBasicConfigurationProc) + version = logger._extract_version_info(PureBasicConfigurationProc()) assert version == {'version': auto_version, 'commit_sha': '192891', 'component_name': 'logger', 'component_version': '1.6.6', 'image_version': 'scale/v:112'} @@ -62,7 +67,7 @@ def test_pod_info_extract(environment: StrStr) -> None: def test_text_logger_init(environment: StrStr) -> None: _mock_image_env(environment) _mock_pod_env(environment) - 
logger.init_logging_from_config(PureBasicConfigurationProc) + logger.init_logging_from_config(PureBasicConfigurationProc()) logger.health("HEALTH data", extra={"metrics": "props"}) logger.metrics("METRICS data", extra={"metrics": "props"}) logger.warning("Warning message here") @@ -89,17 +94,13 @@ def test_json_logger_init(environment: StrStr) -> None: def test_sentry_log_level() -> None: - SentryLoggerCriticalConfiguration.LOG_LEVEL = "CRITICAL" - sll = logger._get_sentry_log_level(SentryLoggerCriticalConfiguration) + sll = logger._get_sentry_log_level(SentryLoggerCriticalConfiguration(log_level="CRITICAL")) assert sll._handler.level == logging._nameToLevel["CRITICAL"] - SentryLoggerCriticalConfiguration.LOG_LEVEL = "ERROR" - sll = logger._get_sentry_log_level(SentryLoggerCriticalConfiguration) + sll = logger._get_sentry_log_level(SentryLoggerCriticalConfiguration(log_level="ERROR")) assert sll._handler.level == logging._nameToLevel["ERROR"] - SentryLoggerCriticalConfiguration.LOG_LEVEL = "WARNING" - sll = logger._get_sentry_log_level(SentryLoggerCriticalConfiguration) + sll = logger._get_sentry_log_level(SentryLoggerCriticalConfiguration(log_level="WARNING")) assert sll._handler.level == logging._nameToLevel["WARNING"] - SentryLoggerCriticalConfiguration.LOG_LEVEL = "INFO" - sll = logger._get_sentry_log_level(SentryLoggerCriticalConfiguration) + sll = logger._get_sentry_log_level(SentryLoggerCriticalConfiguration(log_level="INFO")) assert sll._handler.level == logging._nameToLevel["WARNING"] @@ -107,7 +108,7 @@ def test_sentry_log_level() -> None: def test_sentry_init(environment: StrStr) -> None: _mock_image_env(environment) _mock_pod_env(environment) - logger.init_logging_from_config(SentryLoggerConfiguration) + logger.init_logging_from_config(SentryLoggerConfiguration()) try: 1 / 0 except ZeroDivisionError: diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000..75eb19d185 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,17 @@ +import os + +def pytest_configure(config): + # patch the configurations to use test storage by default, we modify the types (classes) fields + # the dataclass implementation will use those patched values when creating instances (the values present + # in the declaration are not frozen allowing patching) + + from dlt.common.configuration import RunConfiguration, LoadVolumeConfiguration, NormalizeVolumeConfiguration, SchemaVolumeConfiguration + + test_storage_root = "_storage" + RunConfiguration.config_files_storage_path = os.path.join(test_storage_root, "config/%s") + LoadVolumeConfiguration.load_volume_path = os.path.join(test_storage_root, "load") + NormalizeVolumeConfiguration.normalize_volume_path = os.path.join(test_storage_root, "normalize") + SchemaVolumeConfiguration.schema_volume_path = os.path.join(test_storage_root, "schemas") + + assert RunConfiguration.config_files_storage_path == os.path.join(test_storage_root, "config/%s") + assert RunConfiguration().config_files_storage_path == os.path.join(test_storage_root, "config/%s") diff --git a/tests/dbt_runner/test_runner_bigquery.py b/tests/dbt_runner/test_runner_bigquery.py index 24756c3575..8be838a012 100644 --- a/tests/dbt_runner/test_runner_bigquery.py +++ b/tests/dbt_runner/test_runner_bigquery.py @@ -57,8 +57,8 @@ def test_create_folders() -> None: setup_runner("eks_dev_dest", override_values={ "SOURCE_SCHEMA_PREFIX": "carbon_bot_3", "PACKAGE_ADDITIONAL_VARS": {"add_var_name": "add_var_value"}, - "LOG_FORMAT": "JSON", - "LOG_LEVEL": "INFO" + "log_format": "JSON", + 
"log_level": "INFO" }) assert runner.repo_path.endswith(runner.CLONED_PACKAGE_NAME) diff --git a/tests/dbt_runner/test_runner_redshift.py b/tests/dbt_runner/test_runner_redshift.py index d2fc08743e..d5b5b88081 100644 --- a/tests/dbt_runner/test_runner_redshift.py +++ b/tests/dbt_runner/test_runner_redshift.py @@ -62,26 +62,24 @@ def module_autouse() -> None: def test_configuration() -> None: # check names normalized C = make_configuration( - DBTRunnerConfiguration, - DBTRunnerConfiguration, - initial_values={"PACKAGE_REPOSITORY_SSH_KEY": "---NO NEWLINE---", "SOURCE_SCHEMA_PREFIX": "schema"} + DBTRunnerConfiguration(), + initial_value={"PACKAGE_REPOSITORY_SSH_KEY": "---NO NEWLINE---", "SOURCE_SCHEMA_PREFIX": "schema"} ) - assert C.PACKAGE_REPOSITORY_SSH_KEY == "---NO NEWLINE---\n" + assert C.package_repository_ssh_key == "---NO NEWLINE---\n" C = make_configuration( - DBTRunnerConfiguration, - DBTRunnerConfiguration, - initial_values={"PACKAGE_REPOSITORY_SSH_KEY": "---WITH NEWLINE---\n", "SOURCE_SCHEMA_PREFIX": "schema"} + DBTRunnerConfiguration(), + initial_value={"PACKAGE_REPOSITORY_SSH_KEY": "---WITH NEWLINE---\n", "SOURCE_SCHEMA_PREFIX": "schema"} ) - assert C.PACKAGE_REPOSITORY_SSH_KEY == "---WITH NEWLINE---\n" + assert C.package_repository_ssh_key == "---WITH NEWLINE---\n" def test_create_folders() -> None: setup_runner("eks_dev_dest", override_values={ "SOURCE_SCHEMA_PREFIX": "carbon_bot_3", "PACKAGE_ADDITIONAL_VARS": {"add_var_name": "add_var_value"}, - "LOG_FORMAT": "JSON", - "LOG_LEVEL": "INFO" + "log_format": "JSON", + "log_level": "INFO" }) assert runner.repo_path.endswith(runner.CLONED_PACKAGE_NAME) assert runner.profile_name == "rasa_semantic_schema_redshift" @@ -94,7 +92,7 @@ def test_initialize_package_wrong_key() -> None: # private repo "PACKAGE_REPOSITORY_URL": "git@github.com:scale-vector/rasa_bot_experiments.git" }) - runner.CONFIG.PACKAGE_REPOSITORY_SSH_KEY = load_secret("DEPLOY_KEY") + runner.CONFIG.package_repository_ssh_key = load_secret("DEPLOY_KEY") with pytest.raises(GitCommandError): runner.run(None) @@ -104,12 +102,12 @@ def test_reinitialize_package() -> None: setup_runner(DEST_SCHEMA_PREFIX) runner.ensure_newest_package() # mod the package - readme_path = modify_and_commit_file(runner.repo_path, "README.md", content=runner.CONFIG.DEST_SCHEMA_PREFIX) + readme_path = modify_and_commit_file(runner.repo_path, "README.md", content=runner.CONFIG.dest_schema_prefix) assert runner.storage.has_file(readme_path) # this will wipe out old package and clone again runner.ensure_newest_package() # we have old file back - assert runner.storage.load(f"{runner.CLONED_PACKAGE_NAME}/README.md") != runner.CONFIG.DEST_SCHEMA_PREFIX + assert runner.storage.load(f"{runner.CLONED_PACKAGE_NAME}/README.md") != runner.CONFIG.dest_schema_prefix def test_dbt_test_no_raw_schema() -> None: diff --git a/tests/dbt_runner/utils.py b/tests/dbt_runner/utils.py index 304984b93f..4dabd56957 100644 --- a/tests/dbt_runner/utils.py +++ b/tests/dbt_runner/utils.py @@ -48,8 +48,8 @@ def setup_runner(dest_schema_prefix: str, override_values: StrAny = None) -> Non clean_test_storage() C = gen_configuration_variant(initial_values=override_values) # set unique dest schema prefix by default - C.DEST_SCHEMA_PREFIX = dest_schema_prefix - C.PACKAGE_RUN_PARAMS = ["--fail-fast", "--full-refresh"] + C.dest_schema_prefix = dest_schema_prefix + C.package_run_params = ["--fail-fast", "--full-refresh"] # override values including the defaults above if override_values: for k,v in override_values.items(): diff --git 
a/tests/load/bigquery/test_bigquery_client.py b/tests/load/bigquery/test_bigquery_client.py index 3d33eba9ee..8a6a3adfe4 100644 --- a/tests/load/bigquery/test_bigquery_client.py +++ b/tests/load/bigquery/test_bigquery_client.py @@ -67,7 +67,7 @@ def test_bigquery_job_errors(client: BigQueryClient, file_storage: FileStorage) @pytest.mark.parametrize('location', ["US", "EU"]) def test_bigquery_location(location: str, file_storage: FileStorage) -> None: - with cm_yield_client_with_storage("bigquery", initial_values={"LOCATION": location}) as client: + with cm_yield_client_with_storage("bigquery", initial_values={"location": location}) as client: user_table_name = prepare_table(client) load_json = { "_dlt_id": uniq_id(), diff --git a/tests/load/bigquery/test_bigquery_table_builder.py b/tests/load/bigquery/test_bigquery_table_builder.py index 5d3fdacf57..34d44a1aa5 100644 --- a/tests/load/bigquery/test_bigquery_table_builder.py +++ b/tests/load/bigquery/test_bigquery_table_builder.py @@ -20,18 +20,18 @@ def schema() -> Schema: def test_configuration() -> None: # check names normalized with custom_environ({"GCP__PRIVATE_KEY": "---NO NEWLINE---\n"}): - C = make_configuration(GcpClientCredentials, GcpClientCredentials) - assert C.PRIVATE_KEY == "---NO NEWLINE---\n" + C = make_configuration(GcpClientCredentials()) + assert C.private_key == "---NO NEWLINE---\n" with custom_environ({"GCP__PRIVATE_KEY": "---WITH NEWLINE---\n"}): - C = make_configuration(GcpClientCredentials, GcpClientCredentials) - assert C.PRIVATE_KEY == "---WITH NEWLINE---\n" + C = make_configuration(GcpClientCredentials()) + assert C.private_key == "---WITH NEWLINE---\n" @pytest.fixture def gcp_client(schema: Schema) -> BigQueryClient: # return client without opening connection - BigQueryClient.configure(initial_values={"DEFAULT_DATASET": uniq_id()}) + BigQueryClient.configure(initial_values={"default_dataset": uniq_id()}) return BigQueryClient(schema) diff --git a/tests/load/redshift/test_redshift_table_builder.py b/tests/load/redshift/test_redshift_table_builder.py index edcf00873b..b273b1dbec 100644 --- a/tests/load/redshift/test_redshift_table_builder.py +++ b/tests/load/redshift/test_redshift_table_builder.py @@ -20,16 +20,16 @@ def schema() -> Schema: @pytest.fixture def client(schema: Schema) -> RedshiftClient: # return client without opening connection - RedshiftClient.configure(initial_values={"DEFAULT_DATASET": "TEST" + uniq_id()}) + RedshiftClient.configure(initial_values={"default_dataset": "TEST" + uniq_id()}) return RedshiftClient(schema) def test_configuration() -> None: # check names normalized with custom_environ({"PG__DBNAME": "UPPER_CASE_DATABASE", "PG__PASSWORD": " pass\n"}): - C = make_configuration(PostgresCredentials, PostgresCredentials) - assert C.DBNAME == "upper_case_database" - assert C.PASSWORD == "pass" + C = make_configuration(PostgresCredentials()) + assert C.dbname == "upper_case_database" + assert C.password == "pass" def test_create_table(client: RedshiftClient) -> None: diff --git a/tests/load/test_client.py b/tests/load/test_client.py index e547605027..359d44a24d 100644 --- a/tests/load/test_client.py +++ b/tests/load/test_client.py @@ -338,9 +338,9 @@ def test_retrieve_job(client: SqlJobClientBase, file_storage: FileStorage) -> No @pytest.mark.parametrize('client_type', ALL_CLIENT_TYPES) def test_default_schema_name_init_storage(client_type: str) -> None: with cm_yield_client_with_storage(client_type, initial_values={ - "DEFAULT_SCHEMA_NAME": "event" # pass the schema that is a default schema. 
that should create dataset with the name `DEFAULT_DATASET` + "default_schema_name": "event" # pass the schema that is a default schema. that should create dataset with the name `default_dataset` }) as client: - assert client.sql_client.default_dataset_name == client.CONFIG.DEFAULT_DATASET + assert client.sql_client.default_dataset_name == client.CONFIG.default_dataset def prepare_schema(client: SqlJobClientBase, case: str) -> None: diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index e1be8b6243..404e5b46bf 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -15,7 +15,7 @@ from dlt.common.utils import uniq_id from dlt.load.client_base import JobClientBase, LoadEmptyJob, LoadJob -from dlt.load.configuration import configuration, ProductionLoaderConfiguration, LoaderConfiguration +from dlt.load.configuration import configuration, LoaderConfiguration from dlt.load.dummy import client from dlt.load import Load, __version__ from dlt.load.dummy.configuration import DummyClientConfiguration @@ -41,14 +41,10 @@ def logger_autouse() -> None: def test_gen_configuration() -> None: load = setup_loader() - assert ProductionLoaderConfiguration not in load.CONFIG.mro() - assert LoaderConfiguration in load.CONFIG.mro() - # for production config - with patch.dict(environ, {"IS_DEVELOPMENT_CONFIG": "False"}): - # mock missing config values - load = setup_loader(initial_values={"LOAD_VOLUME_PATH": LoaderConfiguration.LOAD_VOLUME_PATH}) - assert ProductionLoaderConfiguration in load.CONFIG.mro() - assert LoaderConfiguration in load.CONFIG.mro() + assert LoaderConfiguration in type(load.CONFIG).mro() + # mock missing config values + load = setup_loader(initial_values={"load_volume_path": LoaderConfiguration.load_volume_path}) + assert LoaderConfiguration in type(load.CONFIG).mro() def test_spool_job_started() -> None: @@ -100,7 +96,7 @@ def test_unsupported_write_disposition() -> None: def test_spool_job_failed() -> None: # this config fails job on start - load = setup_loader(initial_client_values={"FAIL_PROB" : 1.0}) + load = setup_loader(initial_client_values={"fail_prob" : 1.0}) load_id, schema = prepare_load_package( load.load_storage, NORMALIZED_FILES @@ -125,7 +121,7 @@ def test_spool_job_failed() -> None: def test_spool_job_retry_new() -> None: # this config retries job on start (transient fail) - load = setup_loader(initial_client_values={"RETRY_PROB" : 1.0}) + load = setup_loader(initial_client_values={"retry_prob" : 1.0}) load_id, schema = prepare_load_package( load.load_storage, NORMALIZED_FILES @@ -212,27 +208,27 @@ def test_try_retrieve_job() -> None: def test_completed_loop() -> None: - load = setup_loader(initial_client_values={"COMPLETED_PROB": 1.0}) + load = setup_loader(initial_client_values={"completed_prob": 1.0}) assert_complete_job(load, load.load_storage.storage) def test_failed_loop() -> None: # ask to delete completed - load = setup_loader(initial_values={"DELETE_COMPLETED_JOBS": True}, initial_client_values={"FAIL_PROB": 1.0}) + load = setup_loader(initial_values={"delete_completed_jobs": True}, initial_client_values={"fail_prob": 1.0}) # actually not deleted because one of the jobs failed assert_complete_job(load, load.load_storage.storage, should_delete_completed=False) def test_completed_loop_with_delete_completed() -> None: - load = setup_loader(initial_client_values={"COMPLETED_PROB": 1.0}) - load.CONFIG.DELETE_COMPLETED_JOBS = True + load = setup_loader(initial_client_values={"completed_prob": 1.0}) + 
load.CONFIG.delete_completed_jobs = True load.load_storage = load.create_storage(is_storage_owner=False) assert_complete_job(load, load.load_storage.storage, should_delete_completed=True) def test_retry_on_new_loop() -> None: # test job that retries sitting in new jobs - load = setup_loader(initial_client_values={"RETRY_PROB" : 1.0}) + load = setup_loader(initial_client_values={"retry_prob" : 1.0}) load_id, schema = prepare_load_package( load.load_storage, NORMALIZED_FILES @@ -248,7 +244,7 @@ def test_retry_on_new_loop() -> None: files = load.load_storage.list_new_jobs(load_id) assert len(files) == 2 # jobs will be completed - load = setup_loader(initial_client_values={"COMPLETED_PROB" : 1.0}) + load = setup_loader(initial_client_values={"completed_prob" : 1.0}) load.run(ThreadPool()) files = load.load_storage.list_new_jobs(load_id) assert len(files) == 0 @@ -283,7 +279,7 @@ def test_exceptions() -> None: def test_version() -> None: - assert configuration({"CLIENT_TYPE": "dummy"})._VERSION == __version__ + assert configuration({"client_type": "dummy"})._version == __version__ def assert_complete_job(load: Load, storage: FileStorage, should_delete_completed: bool = False) -> None: @@ -330,11 +326,11 @@ def setup_loader(initial_values: StrAny = None, initial_client_values: StrAny = client.JOBS = {} default_values = { - "CLIENT_TYPE": "dummy", - "DELETE_COMPLETED_JOBS": False + "client_type": "dummy", + "delete_completed_jobs": False } default_client_values = { - "LOADER_FILE_FORMAT": "jsonl" + "loader_file_format": "jsonl" } if initial_values: default_values.update(initial_values) diff --git a/tests/load/utils.py b/tests/load/utils.py index 8ad489e51c..29a3c4c99b 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -110,12 +110,12 @@ def yield_client_with_storage(client_type: str, initial_values: StrAny = None) - os.environ.pop("DEFAULT_DATASET", None) # create dataset with random name default_dataset = "test_" + uniq_id() - client_initial_values = {"DEFAULT_DATASET": default_dataset} + client_initial_values = {"default_dataset": default_dataset} if initial_values is not None: client_initial_values.update(initial_values) # get event default schema - C = make_configuration(SchemaVolumeConfiguration, SchemaVolumeConfiguration, initial_values={ - "SCHEMA_VOLUME_PATH": "tests/common/cases/schemas/rasa" + C = make_configuration(SchemaVolumeConfiguration(), initial_value={ + "schema_volume_path": "tests/common/cases/schemas/rasa" }) schema_storage = SchemaStorage(C) schema = schema_storage.load_schema("event") diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py index c68451edb8..4830dd32ea 100644 --- a/tests/normalize/test_normalize.py +++ b/tests/normalize/test_normalize.py @@ -37,10 +37,10 @@ def init_normalize(default_schemas_path: str = None) -> Normalize: clean_test_storage() initial = {} if default_schemas_path: - initial = {"IMPORT_SCHEMA_PATH": default_schemas_path, "EXTERNAL_SCHEMA_FORMAT": "json"} + initial = {"import_schema_path": default_schemas_path, "external_schema_format": "json"} n = Normalize(normalize_configuration(initial), CollectorRegistry()) # set jsonl as default writer - n.load_storage.loader_file_format = n.CONFIG.LOADER_FILE_FORMAT = "jsonl" + n.load_storage.loader_file_format = n.CONFIG.loader_file_format = "jsonl" return n @@ -75,7 +75,7 @@ def test_normalize_single_user_event_jsonl(raw_normalize: Normalize) -> None: def test_normalize_single_user_event_insert(raw_normalize: Normalize) -> None: - 
raw_normalize.load_storage.loader_file_format = raw_normalize.CONFIG.LOADER_FILE_FORMAT = "insert_values" + raw_normalize.load_storage.loader_file_format = raw_normalize.CONFIG.loader_file_format = "insert_values" expected_tables, load_files = normalize_event_user(raw_normalize, "event.event.user_load_1", EXPECTED_USER_TABLES) # verify values line for expected_table in expected_tables: @@ -131,7 +131,7 @@ def test_preserve_slot_complex_value_json_l(rasa_normalize: Normalize) -> None: def test_preserve_slot_complex_value_insert(rasa_normalize: Normalize) -> None: - rasa_normalize.load_storage.loader_file_format = rasa_normalize.CONFIG.LOADER_FILE_FORMAT = "insert_values" + rasa_normalize.load_storage.loader_file_format = rasa_normalize.CONFIG.loader_file_format = "insert_values" load_id = normalize_cases(rasa_normalize, ["event.event.slot_session_metadata_1"]) load_files = expect_load_package(rasa_normalize.load_storage, load_id, ["event", "event_slot"]) event_text, lines = expect_lines_file(rasa_normalize.load_storage, load_files["event_slot"], 2) @@ -154,7 +154,7 @@ def test_normalize_raw_type_hints(rasa_normalize: Normalize) -> None: def test_normalize_many_events_insert(rasa_normalize: Normalize) -> None: - rasa_normalize.load_storage.loader_file_format = rasa_normalize.CONFIG.LOADER_FILE_FORMAT = "insert_values" + rasa_normalize.load_storage.loader_file_format = rasa_normalize.CONFIG.loader_file_format = "insert_values" load_id = normalize_cases(rasa_normalize, ["event.event.many_load_2", "event.event.user_load_1"]) expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + ["event_bot", "event_action"] load_files = expect_load_package(rasa_normalize.load_storage, load_id, expected_tables) @@ -177,7 +177,7 @@ def test_normalize_many_events(rasa_normalize: Normalize) -> None: def test_normalize_many_schemas(rasa_normalize: Normalize) -> None: - rasa_normalize.load_storage.loader_file_format = rasa_normalize.CONFIG.LOADER_FILE_FORMAT = "insert_values" + rasa_normalize.load_storage.loader_file_format = rasa_normalize.CONFIG.loader_file_format = "insert_values" extract_cases( rasa_normalize.normalize_storage, ["event.event.many_load_2", "event.event.user_load_1", "ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2"] @@ -206,7 +206,7 @@ def test_normalize_many_schemas(rasa_normalize: Normalize) -> None: def test_normalize_typed_json(raw_normalize: Normalize) -> None: - raw_normalize.load_storage.loader_file_format = raw_normalize.CONFIG.LOADER_FILE_FORMAT = "jsonl" + raw_normalize.load_storage.loader_file_format = raw_normalize.CONFIG.loader_file_format = "jsonl" extract_items(raw_normalize.normalize_storage, [JSON_TYPED_DICT], "special", "special") raw_normalize.run(ThreadPool(processes=1)) loads = raw_normalize.load_storage.list_packages() @@ -261,7 +261,7 @@ def normalize_cases(normalize: Normalize, cases: Sequence[str]) -> str: def extract_cases(normalize_storage: NormalizeStorage, cases: Sequence[str]) -> None: for case in cases: schema_name, table_name, _ = NormalizeStorage.parse_normalize_file_name(case + ".jsonl") - with open(json_case_path(case), "r") as f: + with open(json_case_path(case), "r", encoding="utf-8") as f: items = json.load(f) extract_items(normalize_storage, items, schema_name, table_name) @@ -296,4 +296,4 @@ def assert_timestamp_data_type(load_storage: LoadStorage, data_type: TDataType) def test_version() -> None: - assert normalize_configuration()._VERSION == __version__ + assert normalize_configuration()._version == __version__ diff --git a/tests/utils.py 
b/tests/utils.py index aa697fe481..b2efc4000f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,7 +6,7 @@ import logging from os import environ -from dlt.common.configuration.utils import _get_config_attrs_with_hints, make_configuration +from dlt.common.configuration.utils import _get_resolvable_fields, make_configuration from dlt.common.configuration import RunConfiguration from dlt.common.logger import init_logging_from_config from dlt.common.file_storage import FileStorage @@ -55,10 +55,10 @@ def preserve_environ() -> None: environ.update(saved_environ) -def init_logger(C: Type[RunConfiguration] = None) -> None: +def init_logger(C: RunConfiguration = None) -> None: if not hasattr(logging, "health"): if not C: - C = make_configuration(RunConfiguration, RunConfiguration) + C = make_configuration(RunConfiguration()) init_logging_from_config(C) @@ -77,15 +77,15 @@ def clean_test_storage(init_normalize: bool = False, init_loader: bool = False) return storage -def add_config_to_env(config: Type[RunConfiguration]) -> None: +def add_config_to_env(config: RunConfiguration) -> None: # write back default values in configuration back into environment - possible_attrs = _get_config_attrs_with_hints(config).keys() + possible_attrs = _get_resolvable_fields(config).keys() for attr in possible_attrs: - if attr not in environ: + env_key = attr.upper() + if env_key not in environ: v = getattr(config, attr) if v is not None: - # print(f"setting {attr} to {v}") - environ[attr] = str(v) + environ[env_key] = str(v) def create_schema_with_name(schema_name) -> Schema: From 277f8e46cbd35c165a1a2ee78c36b9e7858ec264 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 28 Sep 2022 23:50:36 +0200 Subject: [PATCH 20/66] adds init values and embedded config tests --- .../configuration/base_configuration.py | 22 +++-- dlt/common/configuration/exceptions.py | 2 +- dlt/common/configuration/utils.py | 6 -- tests/common/test_configuration.py | 95 +++++++++++++++++++ 4 files changed, 109 insertions(+), 16 deletions(-) diff --git a/dlt/common/configuration/base_configuration.py b/dlt/common/configuration/base_configuration.py index 7b92a4aa57..35b48d142b 100644 --- a/dlt/common/configuration/base_configuration.py +++ b/dlt/common/configuration/base_configuration.py @@ -7,7 +7,7 @@ TDtcField = dataclasses.Field from dlt.common.typing import TAny -from dlt.common.configuration.exceptions import ConfigFieldTypingMissingException +from dlt.common.configuration.exceptions import ConfigFieldMissingAnnotationException def configspec(cls: Type[TAny] = None, /, *, init: bool = False) -> Type[TAny]: @@ -15,12 +15,14 @@ def configspec(cls: Type[TAny] = None, /, *, init: bool = False) -> Type[TAny]: def wrap(cls: Type[TAny]) -> Type[TAny]: # get all annotations without corresponding attributes and set them to None for ann in cls.__annotations__: - if not hasattr(cls, ann): + if not hasattr(cls, ann) and not ann.startswith(("__", "_abc_impl")): setattr(cls, ann, None) # get all attributes without corresponding annotations for att_name, att in cls.__dict__.items(): if not callable(att) and not att_name.startswith(("__", "_abc_impl")) and att_name not in cls.__annotations__: - raise ConfigFieldTypingMissingException(att_name, cls) + print(att) + print(callable(att)) + raise ConfigFieldMissingAnnotationException(att_name, cls) return dataclasses.dataclass(cls, init=init, eq=False) # type: ignore # called with parenthesis @@ -37,7 +39,6 @@ class BaseConfiguration(MutableMapping[str, Any]): __is_partial__: bool = 
dataclasses.field(default = True, init=False, repr=False) # namespace used by config providers when searching for keys __namespace__: str = dataclasses.field(default = None, init=False, repr=False) - __dataclass_fields__: Dict[str, TDtcField] def __init__(self) -> None: self.__ignore_set_unknown_keys = False @@ -69,13 +70,13 @@ def to_native_representation(self) -> Any: # implement dictionary-compatible interface on top of dataclass def __getitem__(self, __key: str) -> Any: - if self._has_attr(__key): + if self.__has_attr(__key): return getattr(self, __key) else: raise KeyError(__key) def __setitem__(self, __key: str, __value: Any) -> None: - if self._has_attr(__key): + if self.__has_attr(__key): setattr(self, __key, __value) else: if not self.__ignore_set_unknown_keys: @@ -85,7 +86,7 @@ def __delitem__(self, __key: str) -> None: raise NotImplementedError("Configuration fields cannot be deleted") def __iter__(self) -> Iterator[str]: - return filter(lambda k: not k.startswith("__"), self.__dataclass_fields__.__iter__()) + return filter(lambda k: not k.startswith("__"), self.__fields_dict().__iter__()) def __len__(self) -> int: return sum(1 for _ in self.__iter__()) @@ -99,8 +100,11 @@ def update(self, other: Any = (), /, **kwds: Any) -> None: # helper functions - def _has_attr(self, __key: str) -> bool: - return __key in self.__dataclass_fields__ and not __key.startswith("__") + def __has_attr(self, __key: str) -> bool: + return __key in self.__fields_dict() and not __key.startswith("__") + + def __fields_dict(self) -> Dict[str, TDtcField]: + return self.__dataclass_fields__ # type: ignore @configspec diff --git a/dlt/common/configuration/exceptions.py b/dlt/common/configuration/exceptions.py index a1ee0c7bce..d15de37a33 100644 --- a/dlt/common/configuration/exceptions.py +++ b/dlt/common/configuration/exceptions.py @@ -53,7 +53,7 @@ def __init__(self, path: str) -> None: super().__init__(f"Missing config file in {path}") -class ConfigFieldTypingMissingException(ConfigurationException): +class ConfigFieldMissingAnnotationException(ConfigurationException): """thrown when configuration specification does not have type annotation""" def __init__(self, field_name: str, typ_: Type[Any]) -> None: diff --git a/dlt/common/configuration/utils.py b/dlt/common/configuration/utils.py index 0c5f86698f..6be8892006 100644 --- a/dlt/common/configuration/utils.py +++ b/dlt/common/configuration/utils.py @@ -48,12 +48,6 @@ def make_configuration(config: TConfiguration, initial_value: Any = None, accept return config -# def is_direct_descendant(child: Type[Any], base: Type[Any]) -> bool: -# # TODO: there may be faster way to get direct descendant that mro -# # note: at index zero there's child -# return base == type.mro(child)[1] - - def _add_module_version(config: TConfiguration) -> None: try: v = sys._getframe(1).f_back.f_globals["__version__"] diff --git a/tests/common/test_configuration.py b/tests/common/test_configuration.py index 838e8fb846..97b0c486b0 100644 --- a/tests/common/test_configuration.py +++ b/tests/common/test_configuration.py @@ -8,6 +8,7 @@ ConfigEnvValueCannotBeCoercedException, BaseConfiguration, utils, configspec) from dlt.common.configuration.utils import make_configuration from dlt.common.configuration.providers import environ as environ_provider +from dlt.common.utils import custom_environ from tests.utils import preserve_environ @@ -148,6 +149,35 @@ class FieldWithNoDefaultConfiguration(RunConfiguration): no_default: str +@configspec(init=True) +class 
InstrumentedConfiguration(BaseConfiguration):
+    head: str
+    tube: List[str]
+    heels: str
+
+    def to_native_representation(self) -> Any:
+        return self.head + ">" + ">".join(self.tube) + ">" + self.heels
+
+    def from_native_representation(self, native_value: Any) -> None:
+        if not isinstance(native_value, str):
+            raise ValueError(native_value)
+        parts = native_value.split(">")
+        self.head = parts[0]
+        self.heels = parts[-1]
+        self.tube = parts[1:-1]
+
+    def check_integrity(self) -> None:
+        if self.head > self.heels:
+            raise RuntimeError("Head over heels")
+
+
+@configspec
+class EmbeddedConfiguration(BaseConfiguration):
+    default: str
+    instrumented: InstrumentedConfiguration
+    namespaced: NamespacedConfiguration
+
+
 LongInteger = NewType("LongInteger", int)
 FirstOrderStr = NewType("FirstOrderStr", str)
 SecondOrderStr = NewType("SecondOrderStr", FirstOrderStr)
@@ -159,6 +189,71 @@ def environment() -> Any:
     return environ
 
 
+def test_initial_config_value() -> None:
+    # set from init method
+    C = make_configuration(InstrumentedConfiguration(head="h", tube=["a", "b"], heels="he"))
+    assert C.to_native_representation() == "h>a>b>he"
+    # set from native form
+    C = make_configuration(InstrumentedConfiguration(), initial_value="h>a>b>he")
+    assert C.head == "h"
+    assert C.tube == ["a", "b"]
+    assert C.heels == "he"
+    # set from dictionary
+    C = make_configuration(InstrumentedConfiguration(), initial_value={"head": "h", "tube": ["tu", "be"], "heels": "xhe"})
+    assert C.to_native_representation() == "h>tu>be>xhe"
+
+
+def test_check_integrity() -> None:
+    with pytest.raises(RuntimeError):
+        # head over heels
+        make_configuration(InstrumentedConfiguration(), initial_value="he>a>b>h")
+
+
+def test_embedded_config(environment: Any) -> None:
+    # resolve all embedded config, using initial value for instrumented config and initial dict for namespaced config
+    C = make_configuration(EmbeddedConfiguration(), initial_value={"default": "set", "instrumented": "h>tu>be>xhe", "namespaced": {"password": "pwd"}})
+    assert C.default == "set"
+    assert C.instrumented.to_native_representation() == "h>tu>be>xhe"
+    assert C.namespaced.password == "pwd"
+
+    # resolve, but provide values via env
+    with custom_environ({"INSTRUMENTED": "h>tu>u>be>xhe", "DLT_TEST__PASSWORD": "passwd", "DEFAULT": "DEF"}):
+        C = make_configuration(EmbeddedConfiguration())
+        assert C.default == "DEF"
+        assert C.instrumented.to_native_representation() == "h>tu>u>be>xhe"
+        assert C.namespaced.password == "passwd"
+
+    # resolve partial, partial is passed to embedded
+    C = make_configuration(EmbeddedConfiguration(), accept_partial=True)
+    assert C.__is_partial__
+    assert C.namespaced.__is_partial__
+    assert C.instrumented.__is_partial__
+
+    # some are partial, some are not
+    with custom_environ({"DLT_TEST__PASSWORD": "passwd"}):
+        C = make_configuration(EmbeddedConfiguration(), accept_partial=True)
+        assert C.__is_partial__
+        assert not C.namespaced.__is_partial__
+        assert C.instrumented.__is_partial__
+
+    # a single integrity error fails all the embeds
+    with custom_environ({"INSTRUMENTED": "he>tu>u>be>h"}):
+        with pytest.raises(RuntimeError):
+            make_configuration(EmbeddedConfiguration(), initial_value={"default": "set", "namespaced": {"password": "pwd"}})
+
+    # part via env, part via initial values
+    with custom_environ({"INSTRUMENTED": "h>tu>u>be>he"}):
+        C = make_configuration(EmbeddedConfiguration(), initial_value={"default": "set", "namespaced": {"password": "pwd"}})
+        assert C.instrumented.to_native_representation() == "h>tu>u>be>he"
+
+
+def 
test_provider_values_over_initial(environment: Any) -> None: + with custom_environ({"INSTRUMENTED": "h>tu>u>be>he"}): + C = make_configuration(EmbeddedConfiguration(), initial_value={"instrumented": "h>tu>be>xhe"}, accept_partial=True) + assert C.instrumented.to_native_representation() == "h>tu>u>be>he" + assert not C.instrumented.__is_partial__ + + def test_run_configuration_gen_name(environment: Any) -> None: C = make_configuration(RunConfiguration()) assert C.pipeline_name.startswith("dlt_") From 400ca1aafbe7d00ca6548e42fc1b1a2a59b383fe Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 30 Sep 2022 15:51:22 +0200 Subject: [PATCH 21/66] reorganizes configuration module --- dlt/common/configuration/__init__.py | 18 ++--- .../configuration/{utils.py => resolve.py} | 12 +-- dlt/common/configuration/specs/__init__.py | 0 .../{ => specs}/base_configuration.py | 0 .../{ => specs}/gcp_client_credentials.py | 5 +- .../{ => specs}/load_volume_configuration.py | 2 +- .../normalize_volume_configuration.py | 2 +- .../{ => specs}/pool_runner_configuration.py | 3 +- .../{ => specs}/postgres_credentials.py | 2 +- .../{ => specs}/run_configuration.py | 2 +- .../schema_volume_configuration.py | 2 +- dlt/common/runners/init.py | 2 +- dlt/common/runners/runnable.py | 7 +- dlt/common/storages/schema_storage.py | 2 +- dlt/load/bigquery/client.py | 2 +- dlt/load/redshift/client.py | 2 +- experiments/pipeline/configuration.py | 2 +- tests/common/runners/test_runners.py | 2 +- tests/common/test_configuration.py | 74 +++++++++---------- tests/dbt_runner/test_runner_redshift.py | 2 +- tests/load/utils.py | 2 +- tests/utils.py | 2 +- 22 files changed, 73 insertions(+), 74 deletions(-) rename dlt/common/configuration/{utils.py => resolve.py} (92%) create mode 100644 dlt/common/configuration/specs/__init__.py rename dlt/common/configuration/{ => specs}/base_configuration.py (100%) rename dlt/common/configuration/{ => specs}/gcp_client_credentials.py (84%) rename dlt/common/configuration/{ => specs}/load_volume_configuration.py (75%) rename dlt/common/configuration/{ => specs}/normalize_volume_configuration.py (64%) rename dlt/common/configuration/{ => specs}/pool_runner_configuration.py (86%) rename dlt/common/configuration/{ => specs}/postgres_credentials.py (88%) rename dlt/common/configuration/{ => specs}/run_configuration.py (94%) rename dlt/common/configuration/{ => specs}/schema_volume_configuration.py (86%) diff --git a/dlt/common/configuration/__init__.py b/dlt/common/configuration/__init__.py index aefd544548..47374436b3 100644 --- a/dlt/common/configuration/__init__.py +++ b/dlt/common/configuration/__init__.py @@ -1,12 +1,12 @@ -from .run_configuration import RunConfiguration # noqa: F401 -from .base_configuration import BaseConfiguration, CredentialsConfiguration, configspec # noqa: F401 -from .normalize_volume_configuration import NormalizeVolumeConfiguration # noqa: F401 -from .load_volume_configuration import LoadVolumeConfiguration # noqa: F401 -from .schema_volume_configuration import SchemaVolumeConfiguration # noqa: F401 -from .pool_runner_configuration import PoolRunnerConfiguration, TPoolType # noqa: F401 -from .gcp_client_credentials import GcpClientCredentials # noqa: F401 -from .postgres_credentials import PostgresCredentials # noqa: F401 -from .utils import make_configuration # noqa: F401 +from .specs.run_configuration import RunConfiguration # noqa: F401 +from .specs.base_configuration import BaseConfiguration, CredentialsConfiguration, configspec # noqa: F401 +from 
.specs.normalize_volume_configuration import NormalizeVolumeConfiguration # noqa: F401 +from .specs.load_volume_configuration import LoadVolumeConfiguration # noqa: F401 +from .specs.schema_volume_configuration import SchemaVolumeConfiguration # noqa: F401 +from .specs.pool_runner_configuration import PoolRunnerConfiguration, TPoolType # noqa: F401 +from .specs.gcp_client_credentials import GcpClientCredentials # noqa: F401 +from .specs.postgres_credentials import PostgresCredentials # noqa: F401 +from .resolve import make_configuration # noqa: F401 from .exceptions import ( # noqa: F401 ConfigEntryMissingException, ConfigEnvValueCannotBeCoercedException, ConfigIntegrityException, ConfigFileNotFoundException) diff --git a/dlt/common/configuration/utils.py b/dlt/common/configuration/resolve.py similarity index 92% rename from dlt/common/configuration/utils.py rename to dlt/common/configuration/resolve.py index 6be8892006..36f8f7290b 100644 --- a/dlt/common/configuration/utils.py +++ b/dlt/common/configuration/resolve.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Mapping, Type, TypeVar from dlt.common.typing import is_optional_type, is_literal_type -from dlt.common.configuration import BaseConfiguration +from dlt.common.configuration.specs.base_configuration import BaseConfiguration from dlt.common.configuration.providers import environ from dlt.common.configuration.exceptions import (ConfigEntryMissingException, ConfigurationWrongTypeException, ConfigEnvValueCannotBeCoercedException) @@ -48,7 +48,7 @@ def make_configuration(config: TConfiguration, initial_value: Any = None, accept return config -def _add_module_version(config: TConfiguration) -> None: +def _add_module_version(config: BaseConfiguration) -> None: try: v = sys._getframe(1).f_back.f_globals["__version__"] semver.VersionInfo.parse(v) @@ -57,7 +57,7 @@ def _add_module_version(config: TConfiguration) -> None: pass -def _resolve_config_fields(config: TConfiguration, fields: Mapping[str, type], accept_partial: bool) -> None: +def _resolve_config_fields(config: BaseConfiguration, fields: Mapping[str, type], accept_partial: bool) -> None: for key, hint in fields.items(): # get default value resolved_value = getattr(config, key, None) @@ -104,7 +104,7 @@ def _coerce_single_value(key: str, value: str, hint: Type[Any]) -> Any: raise ConfigEnvValueCannotBeCoercedException(key, value, hint) from exc -def _is_config_bounded(config: TConfiguration, fields: Mapping[str, type]) -> None: +def _is_config_bounded(config: BaseConfiguration, fields: Mapping[str, type]) -> None: # TODO: here we assume all keys are taken from environ provider, that should change when we introduce more providers _unbound_attrs = [ environ.get_key_name(key, config.__namespace__) for key in fields if getattr(config, key) is None and not is_optional_type(fields[key]) @@ -114,7 +114,7 @@ def _is_config_bounded(config: TConfiguration, fields: Mapping[str, type]) -> No raise ConfigEntryMissingException(_unbound_attrs, config.__namespace__) -def _check_configuration_integrity(config: TConfiguration) -> None: +def _check_configuration_integrity(config: BaseConfiguration) -> None: # python multi-inheritance is cooperative and this would require that all configurations cooperatively # call each other check_integrity. this is not at all possible as we do not know which configs in the end will # be mixed together. 
@@ -128,7 +128,7 @@ def _check_configuration_integrity(config: TConfiguration) -> None: c.__dict__[CHECK_INTEGRITY_F](config) -def _get_resolvable_fields(config: TConfiguration) -> Dict[str, type]: +def _get_resolvable_fields(config: BaseConfiguration) -> Dict[str, type]: return {f.name:f.type for f in dataclasses.fields(config) if not f.name.startswith("__")} diff --git a/dlt/common/configuration/specs/__init__.py b/dlt/common/configuration/specs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dlt/common/configuration/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py similarity index 100% rename from dlt/common/configuration/base_configuration.py rename to dlt/common/configuration/specs/base_configuration.py diff --git a/dlt/common/configuration/gcp_client_credentials.py b/dlt/common/configuration/specs/gcp_client_credentials.py similarity index 84% rename from dlt/common/configuration/gcp_client_credentials.py rename to dlt/common/configuration/specs/gcp_client_credentials.py index 0d42aad990..d51921d7a4 100644 --- a/dlt/common/configuration/gcp_client_credentials.py +++ b/dlt/common/configuration/specs/gcp_client_credentials.py @@ -2,7 +2,7 @@ from dlt.common import json from dlt.common.typing import StrAny, TSecretValue -from dlt.common.configuration.base_configuration import CredentialsConfiguration, configspec +from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec @configspec @@ -18,7 +18,8 @@ class GcpClientCredentials(CredentialsConfiguration): client_email: str = None http_timeout: float = 15.0 - retry_deadline: float = 600 + file_upload_timeout: float = 30 * 60.0 + retry_deadline: float = 600 # how long to retry the operation in case of error, the backoff 60s def from_native_representation(self, initial_value: Any) -> None: if not isinstance(initial_value, str): diff --git a/dlt/common/configuration/load_volume_configuration.py b/dlt/common/configuration/specs/load_volume_configuration.py similarity index 75% rename from dlt/common/configuration/load_volume_configuration.py rename to dlt/common/configuration/specs/load_volume_configuration.py index 9f2fdada3d..b626622a24 100644 --- a/dlt/common/configuration/load_volume_configuration.py +++ b/dlt/common/configuration/specs/load_volume_configuration.py @@ -1,4 +1,4 @@ -from dlt.common.configuration.base_configuration import BaseConfiguration, configspec +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec @configspec diff --git a/dlt/common/configuration/normalize_volume_configuration.py b/dlt/common/configuration/specs/normalize_volume_configuration.py similarity index 64% rename from dlt/common/configuration/normalize_volume_configuration.py rename to dlt/common/configuration/specs/normalize_volume_configuration.py index 12ff684c54..584f271169 100644 --- a/dlt/common/configuration/normalize_volume_configuration.py +++ b/dlt/common/configuration/specs/normalize_volume_configuration.py @@ -1,4 +1,4 @@ -from dlt.common.configuration.base_configuration import BaseConfiguration, configspec +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec @configspec diff --git a/dlt/common/configuration/pool_runner_configuration.py b/dlt/common/configuration/specs/pool_runner_configuration.py similarity index 86% rename from dlt/common/configuration/pool_runner_configuration.py rename to dlt/common/configuration/specs/pool_runner_configuration.py index 
e5ef5665d8..3e7962ed43 100644 --- a/dlt/common/configuration/pool_runner_configuration.py +++ b/dlt/common/configuration/specs/pool_runner_configuration.py @@ -1,6 +1,7 @@ from typing import Literal, Optional -from dlt.common.configuration.run_configuration import RunConfiguration, configspec +from dlt.common.configuration.specs.base_configuration import configspec +from dlt.common.configuration.specs.run_configuration import RunConfiguration TPoolType = Literal["process", "thread", "none"] diff --git a/dlt/common/configuration/postgres_credentials.py b/dlt/common/configuration/specs/postgres_credentials.py similarity index 88% rename from dlt/common/configuration/postgres_credentials.py rename to dlt/common/configuration/specs/postgres_credentials.py index 62b639eac8..cc745de284 100644 --- a/dlt/common/configuration/postgres_credentials.py +++ b/dlt/common/configuration/specs/postgres_credentials.py @@ -1,7 +1,7 @@ from typing import Any from dlt.common.typing import StrAny, TSecretValue -from dlt.common.configuration.base_configuration import CredentialsConfiguration, configspec +from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec @configspec diff --git a/dlt/common/configuration/run_configuration.py b/dlt/common/configuration/specs/run_configuration.py similarity index 94% rename from dlt/common/configuration/run_configuration.py rename to dlt/common/configuration/specs/run_configuration.py index 88af7699d1..7e4c620b65 100644 --- a/dlt/common/configuration/run_configuration.py +++ b/dlt/common/configuration/specs/run_configuration.py @@ -2,7 +2,7 @@ from typing import Any, Optional, Tuple, IO from dlt.common.utils import encoding_for_mode, entry_point_file_stem -from dlt.common.configuration.base_configuration import BaseConfiguration, configspec +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec from dlt.common.configuration.exceptions import ConfigFileNotFoundException diff --git a/dlt/common/configuration/schema_volume_configuration.py b/dlt/common/configuration/specs/schema_volume_configuration.py similarity index 86% rename from dlt/common/configuration/schema_volume_configuration.py rename to dlt/common/configuration/specs/schema_volume_configuration.py index 6e9677ea14..3b1d8c4df9 100644 --- a/dlt/common/configuration/schema_volume_configuration.py +++ b/dlt/common/configuration/specs/schema_volume_configuration.py @@ -1,6 +1,6 @@ from typing import Optional, Literal -from dlt.common.configuration.base_configuration import BaseConfiguration, configspec +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec TSchemaFileFormat = Literal["json", "yaml"] diff --git a/dlt/common/runners/init.py b/dlt/common/runners/init.py index 1a5025374e..f53b08962f 100644 --- a/dlt/common/runners/init.py +++ b/dlt/common/runners/init.py @@ -2,7 +2,7 @@ from typing import Type from dlt.common import logger -from dlt.common.configuration.run_configuration import RunConfiguration +from dlt.common.configuration import RunConfiguration from dlt.common.logger import init_logging_from_config, init_telemetry from dlt.common.signals import register_signals diff --git a/dlt/common/runners/runnable.py b/dlt/common/runners/runnable.py index a4e775c0bb..c3fce997d6 100644 --- a/dlt/common/runners/runnable.py +++ b/dlt/common/runners/runnable.py @@ -1,13 +1,10 @@ -import inspect from abc import ABC, abstractmethod from functools import wraps -from typing import Any, Dict, Mapping, Type, 
TypeVar, TYPE_CHECKING, Union, Generic, get_args +from typing import Any, Dict, Type, TypeVar, TYPE_CHECKING, Union, Generic from multiprocessing.pool import Pool from weakref import WeakValueDictionary -from dlt.common.configuration.run_configuration import BaseConfiguration -from dlt.common.typing import StrAny, TFun -from dlt.common.utils import uniq_id +from dlt.common.typing import TFun from dlt.common.telemetry import TRunMetrics TPool = TypeVar("TPool", bound=Pool) diff --git a/dlt/common/storages/schema_storage.py b/dlt/common/storages/schema_storage.py index bfe1912312..6393e6ab52 100644 --- a/dlt/common/storages/schema_storage.py +++ b/dlt/common/storages/schema_storage.py @@ -4,7 +4,7 @@ from typing import Iterator, List, Mapping from dlt.common import json, logger -from dlt.common.configuration.schema_volume_configuration import TSchemaFileFormat +from dlt.common.configuration.specs.schema_volume_configuration import TSchemaFileFormat from dlt.common.file_storage import FileStorage from dlt.common.schema import Schema, verify_schema_hash from dlt.common.typing import DictStrAny diff --git a/dlt/load/bigquery/client.py b/dlt/load/bigquery/client.py index 20dfe7b4df..fae2104d81 100644 --- a/dlt/load/bigquery/client.py +++ b/dlt/load/bigquery/client.py @@ -333,7 +333,7 @@ def _create_load_job(self, table_name: str, write_disposition: TWriteDisposition self.sql_client.make_qualified_table_name(table_name), job_id=job_id, job_config=job_config, - timeout=self.CREDENTIALS.http_timeout + timeout=self.CREDENTIALS.file_upload_timeout ) def _retrieve_load_job(self, file_path: str) -> bigquery.LoadJob: diff --git a/dlt/load/redshift/client.py b/dlt/load/redshift/client.py index 889615c241..c35796e384 100644 --- a/dlt/load/redshift/client.py +++ b/dlt/load/redshift/client.py @@ -10,7 +10,7 @@ from contextlib import contextmanager from typing import Any, AnyStr, Dict, Iterator, List, Optional, Sequence, Tuple -from dlt.common.configuration.postgres_credentials import PostgresCredentials +from dlt.common.configuration.specs.postgres_credentials import PostgresCredentials from dlt.common.typing import StrAny from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE diff --git a/experiments/pipeline/configuration.py b/experiments/pipeline/configuration.py index 64101b4139..4bfddc0ac4 100644 --- a/experiments/pipeline/configuration.py +++ b/experiments/pipeline/configuration.py @@ -9,7 +9,7 @@ from dlt.common.typing import DictStrAny, StrAny, TAny, TFun from dlt.common.configuration import BaseConfiguration -from dlt.common.configuration.utils import NON_EVAL_TYPES, make_configuration, SIMPLE_TYPES +from dlt.common.configuration.resolve import NON_EVAL_TYPES, make_configuration, SIMPLE_TYPES # _POS_PARAMETER_KINDS = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD, Parameter.VAR_POSITIONAL) diff --git a/tests/common/runners/test_runners.py b/tests/common/runners/test_runners.py index 666839b9dc..29dcf7027f 100644 --- a/tests/common/runners/test_runners.py +++ b/tests/common/runners/test_runners.py @@ -6,7 +6,7 @@ from dlt.cli import TRunnerArgs from dlt.common import signals from dlt.common.configuration import PoolRunnerConfiguration, make_configuration, configspec -from dlt.common.configuration.pool_runner_configuration import TPoolType +from dlt.common.configuration.specs.pool_runner_configuration import TPoolType from dlt.common.exceptions import DltException, SignalReceivedException, TimeRangeExhaustedException, UnsupportedProcessStartMethodException 
from dlt.common.runners import pool_runner as runner diff --git a/tests/common/test_configuration.py b/tests/common/test_configuration.py index 97b0c486b0..c6b1b6c77a 100644 --- a/tests/common/test_configuration.py +++ b/tests/common/test_configuration.py @@ -5,8 +5,8 @@ from dlt.common.typing import TSecretValue from dlt.common.configuration import ( RunConfiguration, ConfigEntryMissingException, ConfigFileNotFoundException, - ConfigEnvValueCannotBeCoercedException, BaseConfiguration, utils, configspec) -from dlt.common.configuration.utils import make_configuration + ConfigEnvValueCannotBeCoercedException, BaseConfiguration, resolve, configspec) +from dlt.common.configuration.resolve import make_configuration from dlt.common.configuration.providers import environ as environ_provider from dlt.common.utils import custom_environ @@ -351,8 +351,8 @@ class MultiConfiguration(MockProdConfiguration, ConfigurationWithOptionalTypes, def test_raises_on_unresolved_fields() -> None: with pytest.raises(ConfigEntryMissingException) as config_entry_missing_exception: C = WrongConfiguration() - keys = utils._get_resolvable_fields(C) - utils._is_config_bounded(C, keys) + keys = resolve._get_resolvable_fields(C) + resolve._is_config_bounded(C, keys) assert 'NONECONFIGVAR' in config_entry_missing_exception.value.missing_set @@ -364,8 +364,8 @@ def test_raises_on_unresolved_fields() -> None: def test_optional_types_are_not_required() -> None: # this should not raise an exception - keys = utils._get_resolvable_fields(ConfigurationWithOptionalTypes()) - utils._is_config_bounded(ConfigurationWithOptionalTypes(), keys) + keys = resolve._get_resolvable_fields(ConfigurationWithOptionalTypes()) + resolve._is_config_bounded(ConfigurationWithOptionalTypes(), keys) # make optional config make_configuration(ConfigurationWithOptionalTypes()) # make config with optional values @@ -376,9 +376,9 @@ def test_configuration_apply_adds_environment_variable_to_config(environment: An environment["NONECONFIGVAR"] = "Some" C = WrongConfiguration() - keys = utils._get_resolvable_fields(C) - utils._resolve_config_fields(C, keys, accept_partial=False) - utils._is_config_bounded(C, keys) + keys = resolve._get_resolvable_fields(C) + resolve._resolve_config_fields(C, keys, accept_partial=False) + resolve._is_config_bounded(C, keys) assert C.NoneConfigVar == environment["NONECONFIGVAR"] @@ -387,16 +387,16 @@ def test_configuration_resolve_env_var(environment: Any) -> None: environment["TEST_BOOL"] = 'True' C = SimpleConfiguration() - keys = utils._get_resolvable_fields(C) - utils._resolve_config_fields(C, keys, accept_partial=False) - utils._is_config_bounded(C, keys) + keys = resolve._get_resolvable_fields(C) + resolve._resolve_config_fields(C, keys, accept_partial=False) + resolve._is_config_bounded(C, keys) # value will be coerced to bool assert C.test_bool is True def test_find_all_keys() -> None: - keys = utils._get_resolvable_fields(VeryWrongConfiguration()) + keys = resolve._get_resolvable_fields(VeryWrongConfiguration()) # assert hints and types: LOG_COLOR had it hint overwritten in derived class assert set({'str_val': str, 'int_val': int, 'NoneConfigVar': str, 'log_color': str}.items()).issubset(keys.items()) @@ -406,9 +406,9 @@ def test_coercions(environment: Any) -> None: environment[key.upper()] = str(value) C = CoercionTestConfiguration() - keys = utils._get_resolvable_fields(C) - utils._resolve_config_fields(C, keys, accept_partial=False) - utils._is_config_bounded(C, keys) + keys = resolve._get_resolvable_fields(C) + 
resolve._resolve_config_fields(C, keys, accept_partial=False) + resolve._is_config_bounded(C, keys) for key in COERCIONS: assert getattr(C, key) == COERCIONS[key] @@ -416,11 +416,11 @@ def test_coercions(environment: Any) -> None: def test_invalid_coercions(environment: Any) -> None: C = CoercionTestConfiguration() - config_keys = utils._get_resolvable_fields(C) + config_keys = resolve._get_resolvable_fields(C) for key, value in INVALID_COERCIONS.items(): try: environment[key.upper()] = str(value) - utils._resolve_config_fields(C, config_keys, accept_partial=False) + resolve._resolve_config_fields(C, config_keys, accept_partial=False) except ConfigEnvValueCannotBeCoercedException as coerc_exc: # must fail exactly on expected value if coerc_exc.attr_name != key: @@ -433,10 +433,10 @@ def test_invalid_coercions(environment: Any) -> None: def test_excepted_coercions(environment: Any) -> None: C = CoercionTestConfiguration() - config_keys = utils._get_resolvable_fields(C) + config_keys = resolve._get_resolvable_fields(C) for k, v in EXCEPTED_COERCIONS.items(): environment[k.upper()] = str(v) - utils._resolve_config_fields(C, config_keys, accept_partial=False) + resolve._resolve_config_fields(C, config_keys, accept_partial=False) for key in EXCEPTED_COERCIONS: assert getattr(C, key) == COERCED_EXCEPTIONS[key] @@ -444,7 +444,7 @@ def test_excepted_coercions(environment: Any) -> None: def test_make_configuration(environment: Any) -> None: # fill up configuration environment["NONECONFIGVAR"] = "1" - C = utils.make_configuration(WrongConfiguration()) + C = resolve.make_configuration(WrongConfiguration()) assert not C.__is_partial__ assert C.NoneConfigVar == "1" @@ -452,7 +452,7 @@ def test_make_configuration(environment: Any) -> None: def test_auto_derivation(environment: Any) -> None: # make_configuration works on instances of dataclasses and types are not modified environment['SECRET_VALUE'] = "1" - C = utils.make_configuration(SecretConfiguration()) + C = resolve.make_configuration(SecretConfiguration()) # auto derived type holds the value assert C.secret_value == "1" # base type is untouched @@ -491,11 +491,11 @@ def test_finds_version(environment: Any) -> None: global __version__ v = __version__ - C = utils.make_configuration(SimpleConfiguration()) + C = resolve.make_configuration(SimpleConfiguration()) assert C._version == v try: del globals()["__version__"] - C = utils.make_configuration(SimpleConfiguration()) + C = resolve.make_configuration(SimpleConfiguration()) assert not hasattr(C, "_version") finally: __version__ = v @@ -503,9 +503,9 @@ def test_finds_version(environment: Any) -> None: def test_secret(environment: Any) -> None: with pytest.raises(ConfigEntryMissingException): - utils.make_configuration(SecretConfiguration()) + resolve.make_configuration(SecretConfiguration()) environment['SECRET_VALUE'] = "1" - C = utils.make_configuration(SecretConfiguration()) + C = resolve.make_configuration(SecretConfiguration()) assert C.secret_value == "1" # mock the path to point to secret storage # from dlt.common.configuration import config_utils @@ -514,18 +514,18 @@ def test_secret(environment: Any) -> None: try: # must read a secret file environ_provider.SECRET_STORAGE_PATH = "./tests/common/cases/%s" - C = utils.make_configuration(SecretConfiguration()) + C = resolve.make_configuration(SecretConfiguration()) assert C.secret_value == "BANANA" # set some weird path, no secret file at all del environment['SECRET_VALUE'] environ_provider.SECRET_STORAGE_PATH = "!C:\\PATH%s" with 
pytest.raises(ConfigEntryMissingException): - utils.make_configuration(SecretConfiguration()) + resolve.make_configuration(SecretConfiguration()) # set env which is a fallback for secret not as file environment['SECRET_VALUE'] = "1" - C = utils.make_configuration(SecretConfiguration()) + C = resolve.make_configuration(SecretConfiguration()) assert C.secret_value == "1" finally: environ_provider.SECRET_STORAGE_PATH = path @@ -535,7 +535,7 @@ def test_secret_kube_fallback(environment: Any) -> None: path = environ_provider.SECRET_STORAGE_PATH try: environ_provider.SECRET_STORAGE_PATH = "./tests/common/cases/%s" - C = utils.make_configuration(SecretKubeConfiguration()) + C = resolve.make_configuration(SecretKubeConfiguration()) # all unix editors will add x10 at the end of file, it will be preserved assert C.secret_kube == "kube\n" # we propagate secrets back to environ and strip the whitespace @@ -571,7 +571,7 @@ def test_coerce_values() -> None: def test_configuration_files(environment: Any) -> None: # overwrite config file paths environment["CONFIG_FILES_STORAGE_PATH"] = "./tests/common/cases/schemas/ev1/%s" - C = utils.make_configuration(MockProdConfigurationVar()) + C = resolve.make_configuration(MockProdConfigurationVar()) assert C.config_files_storage_path == environment["CONFIG_FILES_STORAGE_PATH"] assert C.has_configuration_file("hasn't") is False assert C.has_configuration_file("event_schema.json") is True @@ -584,22 +584,22 @@ def test_configuration_files(environment: Any) -> None: def test_namespaced_configuration(environment: Any) -> None: with pytest.raises(ConfigEntryMissingException) as exc_val: - utils.make_configuration(NamespacedConfiguration()) + resolve.make_configuration(NamespacedConfiguration()) assert exc_val.value.missing_set == ["DLT_TEST__PASSWORD"] assert exc_val.value.namespace == "DLT_TEST" # init vars work without namespace - C = utils.make_configuration(NamespacedConfiguration(), initial_value={"password": "PASS"}) + C = resolve.make_configuration(NamespacedConfiguration(), initial_value={"password": "PASS"}) assert C.password == "PASS" # env var must be prefixed environment["PASSWORD"] = "PASS" with pytest.raises(ConfigEntryMissingException) as exc_val: - utils.make_configuration(NamespacedConfiguration()) + resolve.make_configuration(NamespacedConfiguration()) environment["DLT_TEST__PASSWORD"] = "PASS" - C = utils.make_configuration(NamespacedConfiguration()) + C = resolve.make_configuration(NamespacedConfiguration()) assert C.password == "PASS" def coerce_single_value(key: str, value: str, hint: Type[Any]) -> Any: - hint = utils._extract_simple_type(hint) - return utils._coerce_single_value(key, value, hint) + hint = resolve._extract_simple_type(hint) + return resolve._coerce_single_value(key, value, hint) diff --git a/tests/dbt_runner/test_runner_redshift.py b/tests/dbt_runner/test_runner_redshift.py index d5b5b88081..176cf2786a 100644 --- a/tests/dbt_runner/test_runner_redshift.py +++ b/tests/dbt_runner/test_runner_redshift.py @@ -5,7 +5,7 @@ from dlt.common import logger from dlt.common.configuration import PostgresCredentials -from dlt.common.configuration.utils import make_configuration +from dlt.common.configuration.resolve import make_configuration from dlt.common.file_storage import FileStorage from dlt.common.telemetry import TRunMetrics, get_metrics_from_prometheus from dlt.common.typing import StrStr diff --git a/tests/load/utils.py b/tests/load/utils.py index 29a3c4c99b..367ab55e34 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ 
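The secret-handling tests above also document the lookup order of the environ provider: a secret field is first looked up as a file under SECRET_STORAGE_PATH, a template in which %s is replaced with the key name, and only falls back to the plain environment variable when no such file exists. A sketch under those assumptions (the ApiSecrets spec is illustrative, and accept_partial keeps the example from raising when neither source is present):

    from dlt.common.typing import TSecretValue
    from dlt.common.configuration import configspec
    from dlt.common.configuration.providers import environ as environ_provider
    from dlt.common.configuration.resolve import make_configuration
    from dlt.common.configuration.specs import BaseConfiguration

    @configspec(init=True)
    class ApiSecrets(BaseConfiguration):   # illustrative spec
        secret_value: TSecretValue = None

    # point the provider at a directory of secret files; "%s" becomes the key name
    environ_provider.SECRET_STORAGE_PATH = "./secrets/%s"
    config = make_configuration(ApiSecrets(), accept_partial=True)
    # reads ./secrets/<key> if such a file exists, otherwise the SECRET_VALUE env var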
-4,7 +4,7 @@ from dlt.common import json, Decimal from dlt.common.configuration import make_configuration -from dlt.common.configuration.schema_volume_configuration import SchemaVolumeConfiguration +from dlt.common.configuration.specs.schema_volume_configuration import SchemaVolumeConfiguration from dlt.common.data_writers import DataWriter from dlt.common.file_storage import FileStorage from dlt.common.schema import TColumnSchema, TTableSchemaColumns diff --git a/tests/utils.py b/tests/utils.py index b2efc4000f..57610f43bf 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,7 +6,7 @@ import logging from os import environ -from dlt.common.configuration.utils import _get_resolvable_fields, make_configuration +from dlt.common.configuration.resolve import _get_resolvable_fields, make_configuration from dlt.common.configuration import RunConfiguration from dlt.common.logger import init_logging_from_config from dlt.common.file_storage import FileStorage From b53ed23bcf8b55ee5ac67ed853104b4d10201d57 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 1 Oct 2022 15:24:04 +0200 Subject: [PATCH 22/66] moves configurations into specs folder --- dlt/common/configuration/__init__.py | 9 +--- dlt/common/configuration/specs/__init__.py | 8 ++++ .../configuration/specs/base_configuration.py | 6 +-- dlt/common/exceptions.py | 3 -- dlt/common/json.py | 30 ++++++------- dlt/common/logger.py | 10 +++-- dlt/common/runners/init.py | 3 +- dlt/common/runners/pool_runner.py | 2 +- dlt/common/schema/schema.py | 6 +-- dlt/common/signals.py | 6 +-- dlt/common/storages/exceptions.py | 5 ++- dlt/common/storages/live_schema_storage.py | 2 +- dlt/common/storages/load_storage.py | 2 +- dlt/common/storages/normalize_storage.py | 5 +-- dlt/common/storages/schema_storage.py | 3 +- dlt/common/storages/versioned_storage.py | 2 +- dlt/common/utils.py | 24 ++++++----- dlt/common/validation.py | 2 +- dlt/common/wei.py | 3 +- dlt/dbt_runner/configuration.py | 3 +- dlt/dbt_runner/runner.py | 24 +++++------ dlt/helpers/streamlit.py | 15 +++---- dlt/load/bigquery/client.py | 2 +- dlt/load/bigquery/configuration.py | 3 +- dlt/load/client_base.py | 2 +- dlt/load/configuration.py | 3 +- dlt/load/dummy/client.py | 5 +-- dlt/load/exceptions.py | 4 +- dlt/load/redshift/client.py | 2 +- dlt/load/redshift/configuration.py | 3 +- dlt/normalize/configuration.py | 5 +-- dlt/pipeline/pipeline.py | 3 +- experiments/pipeline/configuration.py | 2 +- tests/common/runners/test_runners.py | 4 +- tests/common/schema/test_schema.py | 3 +- tests/common/storages/test_loader_storage.py | 3 +- .../common/storages/test_normalize_storage.py | 7 ++-- tests/common/storages/test_schema_storage.py | 6 +-- tests/common/test_configuration.py | 42 +++++++++---------- tests/common/test_logging.py | 5 ++- tests/conftest.py | 2 +- tests/dbt_runner/test_runner_bigquery.py | 4 +- tests/dbt_runner/test_runner_redshift.py | 4 +- tests/dbt_runner/utils.py | 2 +- .../bigquery/test_bigquery_table_builder.py | 3 +- .../redshift/test_redshift_table_builder.py | 3 +- tests/tools/create_storages.py | 2 +- tests/utils.py | 6 +-- 48 files changed, 157 insertions(+), 146 deletions(-) diff --git a/dlt/common/configuration/__init__.py b/dlt/common/configuration/__init__.py index 47374436b3..30f9a4a0d2 100644 --- a/dlt/common/configuration/__init__.py +++ b/dlt/common/configuration/__init__.py @@ -1,11 +1,4 @@ -from .specs.run_configuration import RunConfiguration # noqa: F401 -from .specs.base_configuration import BaseConfiguration, CredentialsConfiguration, configspec # noqa: 
F401 -from .specs.normalize_volume_configuration import NormalizeVolumeConfiguration # noqa: F401 -from .specs.load_volume_configuration import LoadVolumeConfiguration # noqa: F401 -from .specs.schema_volume_configuration import SchemaVolumeConfiguration # noqa: F401 -from .specs.pool_runner_configuration import PoolRunnerConfiguration, TPoolType # noqa: F401 -from .specs.gcp_client_credentials import GcpClientCredentials # noqa: F401 -from .specs.postgres_credentials import PostgresCredentials # noqa: F401 +from .specs.base_configuration import configspec # noqa: F401 from .resolve import make_configuration # noqa: F401 from .exceptions import ( # noqa: F401 diff --git a/dlt/common/configuration/specs/__init__.py b/dlt/common/configuration/specs/__init__.py index e69de29bb2..f6c33bceb2 100644 --- a/dlt/common/configuration/specs/__init__.py +++ b/dlt/common/configuration/specs/__init__.py @@ -0,0 +1,8 @@ +from .run_configuration import RunConfiguration # noqa: F401 +from .base_configuration import BaseConfiguration, CredentialsConfiguration # noqa: F401 +from .normalize_volume_configuration import NormalizeVolumeConfiguration # noqa: F401 +from .load_volume_configuration import LoadVolumeConfiguration # noqa: F401 +from .schema_volume_configuration import SchemaVolumeConfiguration, TSchemaFileFormat # noqa: F401 +from .pool_runner_configuration import PoolRunnerConfiguration, TPoolType # noqa: F401 +from .gcp_client_credentials import GcpClientCredentials # noqa: F401 +from .postgres_credentials import PostgresCredentials # noqa: F401 \ No newline at end of file diff --git a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py index 35b48d142b..75ccf9a600 100644 --- a/dlt/common/configuration/specs/base_configuration.py +++ b/dlt/common/configuration/specs/base_configuration.py @@ -54,7 +54,7 @@ def from_native_representation(self, native_value: Any) -> None: NotImplementedError: This configuration does not have a native representation ValueError: The value provided cannot be parsed as native representation """ - raise NotImplementedError() + raise ValueError() def to_native_representation(self) -> Any: """Represents the configuration instance in its native form ie. database connection string or JSON serialized GCP service credentials file. 
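With this hunk from_native_representation raises ValueError by default (the hunk that follows applies the same change to to_native_representation), so callers can treat "no native form" and "unparsable native value" as one condition. A sketch of a spec that does define a native form, modelled on the delimiter-joined representation the configuration tests later in this patch work with (the class name and separator are illustrative):

    from typing import Any, List
    from dlt.common.configuration import configspec
    from dlt.common.configuration.specs import BaseConfiguration

    @configspec(init=True)
    class DelimitedConfiguration(BaseConfiguration):
        head: str = None
        tube: List[str] = None
        heels: str = None

        def from_native_representation(self, native_value: Any) -> None:
            if not isinstance(native_value, str):
                raise ValueError(native_value)   # cannot be parsed as the native form
            self.head, *self.tube, self.heels = native_value.split(">")

        def to_native_representation(self) -> str:
            return ">".join([self.head] + self.tube + [self.heels])

    c = DelimitedConfiguration()
    c.from_native_representation("h>a>b>he")
    assert c.to_native_representation() == "h>a>b>he"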
@@ -65,7 +65,7 @@ def to_native_representation(self) -> Any: Returns: Any: A native representation of the configuration """ - raise NotImplementedError() + raise ValueError() # implement dictionary-compatible interface on top of dataclass @@ -83,7 +83,7 @@ def __setitem__(self, __key: str, __value: Any) -> None: raise KeyError(__key) def __delitem__(self, __key: str) -> None: - raise NotImplementedError("Configuration fields cannot be deleted") + raise KeyError("Configuration fields cannot be deleted") def __iter__(self) -> Iterator[str]: return filter(lambda k: not k.startswith("__"), self.__fields_dict().__iter__()) diff --git a/dlt/common/exceptions.py b/dlt/common/exceptions.py index 1f88dac73b..440310c302 100644 --- a/dlt/common/exceptions.py +++ b/dlt/common/exceptions.py @@ -51,21 +51,18 @@ class TerminalException(Exception): """ Marks an exception that cannot be recovered from, should be mixed in into concrete exception class """ - pass class TransientException(Exception): """ Marks an exception in operation that can be retried, should be mixed in into concrete exception class """ - pass class TerminalValueError(ValueError, TerminalException): """ ValueError that is unrecoverable """ - pass class TimeRangeExhaustedException(DltException): diff --git a/dlt/common/json.py b/dlt/common/json.py index cb578c0357..6f2d3cf5d7 100644 --- a/dlt/common/json.py +++ b/dlt/common/json.py @@ -2,7 +2,7 @@ import pendulum from datetime import date, datetime # noqa: I251 from functools import partial -from typing import Any, Callable, Union +from typing import Any, Callable, List, Union from uuid import UUID from hexbytes import HexBytes import simplejson @@ -43,22 +43,22 @@ def custom_encode(obj: Any) -> str: # use PUA range to encode additional types -_DECIMAL = u'\uF026' -_DATETIME = u'\uF027' -_DATE = u'\uF028' -_UUIDT = u'\uF029' -_HEXBYTES = u'\uF02A' -_B64BYTES = u'\uF02B' -_WEI = u'\uF02C' +_DECIMAL = '\uF026' +_DATETIME = '\uF027' +_DATE = '\uF028' +_UUIDT = '\uF029' +_HEXBYTES = '\uF02A' +_B64BYTES = '\uF02B' +_WEI = '\uF02C' -DECODERS = [ - lambda s: Decimal(s), - lambda s: pendulum.parse(s), +DECODERS: List[Callable[[Any], Any]] = [ + Decimal, + pendulum.parse, lambda s: pendulum.parse(s).date(), # type: ignore - lambda s: UUID(s), - lambda s: HexBytes(s), - lambda s: base64.b64decode(s), - lambda s: Wei(s) + UUID, + HexBytes, + base64.b64decode, + Wei ] diff --git a/dlt/common/logger.py b/dlt/common/logger.py index 2c97c7f71b..53924c1e41 100644 --- a/dlt/common/logger.py +++ b/dlt/common/logger.py @@ -9,13 +9,15 @@ from dlt.common.json import json from dlt.common.typing import DictStrAny, StrStr -from dlt.common.configuration import RunConfiguration +from dlt.common.configuration.specs import RunConfiguration from dlt.common.utils import filter_env_vars + from dlt._version import common_version as __version__ DLT_LOGGER_NAME = "sv-dlt" LOGGER: Logger = None + def _add_logging_level(level_name: str, level: int, method_name:str = None) -> None: """ Comprehensively adds a new logging level to the `logging` module and the @@ -36,11 +38,11 @@ def _add_logging_level(level_name: str, level: int, method_name:str = None) -> N method_name = level_name.lower() if hasattr(logging, level_name): - raise AttributeError('{} already defined in logging module'.format(level_name)) + raise AttributeError('{} already defined in logging module'.format(level_name)) if hasattr(logging, method_name): - raise AttributeError('{} already defined in logging module'.format(method_name)) + raise AttributeError('{} 
already defined in logging module'.format(method_name)) if hasattr(logging.getLoggerClass(), method_name): - raise AttributeError('{} already defined in logger class'.format(method_name)) + raise AttributeError('{} already defined in logger class'.format(method_name)) # This method was inspired by the answers to Stack Overflow post # http://stackoverflow.com/q/2183233/2988730, especially diff --git a/dlt/common/runners/init.py b/dlt/common/runners/init.py index f53b08962f..41c536bf82 100644 --- a/dlt/common/runners/init.py +++ b/dlt/common/runners/init.py @@ -1,10 +1,9 @@ import threading -from typing import Type from dlt.common import logger -from dlt.common.configuration import RunConfiguration from dlt.common.logger import init_logging_from_config, init_telemetry from dlt.common.signals import register_signals +from dlt.common.configuration.specs import RunConfiguration # signals and telemetry should be initialized only once _INITIALIZED = False diff --git a/dlt/common/runners/pool_runner.py b/dlt/common/runners/pool_runner.py index be1a271de8..831dad5cdc 100644 --- a/dlt/common/runners/pool_runner.py +++ b/dlt/common/runners/pool_runner.py @@ -8,7 +8,7 @@ from dlt.common.time import sleep from dlt.common.telemetry import TRunHealth, TRunMetrics, get_logging_extras, get_metrics_from_prometheus from dlt.common.exceptions import SignalReceivedException, TimeRangeExhaustedException, UnsupportedProcessStartMethodException -from dlt.common.configuration import PoolRunnerConfiguration +from dlt.common.configuration.specs import PoolRunnerConfiguration HEALTH_PROPS_GAUGES: Dict[str, Union[Counter, Gauge]] = None diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index ba1a1af317..34fedce7fa 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -448,14 +448,14 @@ def _compile_regexes(self) -> None: self._compiled_preferred_types.append((utils.compile_simple_regex(pattern), dt)) for hint_name, hint_list in self._settings.get("default_hints", {}).items(): # compile hints which are column matching regexes - self._compiled_hints[hint_name] = list(map(lambda hint: utils.compile_simple_regex(hint), hint_list)) + self._compiled_hints[hint_name] = list(map(utils.compile_simple_regex, hint_list)) if self._schema_tables: for table in self._schema_tables.values(): if "filters" in table: if "excludes" in table["filters"]: - self._compiled_excludes[table["name"]] = list(map(lambda exclude: utils.compile_simple_regex(exclude), table["filters"]["excludes"])) + self._compiled_excludes[table["name"]] = list(map(utils.compile_simple_regex, table["filters"]["excludes"])) if "includes" in table["filters"]: - self._compiled_includes[table["name"]] = list(map(lambda exclude: utils.compile_simple_regex(exclude), table["filters"]["includes"])) + self._compiled_includes[table["name"]] = list(map(utils.compile_simple_regex, table["filters"]["includes"])) def __repr__(self) -> str: return f"Schema {self.name} at {id(self)}" diff --git a/dlt/common/signals.py b/dlt/common/signals.py index 2202bcb502..7a06d66d0f 100644 --- a/dlt/common/signals.py +++ b/dlt/common/signals.py @@ -12,16 +12,16 @@ exit_event = Event() -def signal_receiver(signal: int, frame: Any) -> None: +def signal_receiver(sig: int, frame: Any) -> None: global _received_signal - logger.info(f"Signal {signal} received") + logger.info(f"Signal {sig} received") if _received_signal > 0: logger.info(f"Another signal received after {_received_signal}") return - _received_signal = signal + _received_signal = sig # awake 
all threads sleeping on event exit_event.set() diff --git a/dlt/common/storages/exceptions.py b/dlt/common/storages/exceptions.py index 55b55b0711..4f5d2e3551 100644 --- a/dlt/common/storages/exceptions.py +++ b/dlt/common/storages/exceptions.py @@ -32,10 +32,11 @@ class LoaderStorageException(StorageException): class JobWithUnsupportedWriterException(LoaderStorageException): - def __init__(self, load_id: str, expected_file_format: Iterable[TLoaderFileFormat], wrong_job: str) -> None: + def __init__(self, load_id: str, expected_file_formats: Iterable[TLoaderFileFormat], wrong_job: str) -> None: self.load_id = load_id - self.expected_file_format = expected_file_format + self.expected_file_formats = expected_file_formats self.wrong_job = wrong_job + super().__init__(f"Job {wrong_job} for load id {load_id} requires loader file format that is not one of {expected_file_formats}") class SchemaStorageException(StorageException): diff --git a/dlt/common/storages/live_schema_storage.py b/dlt/common/storages/live_schema_storage.py index 47b6265f28..b74a6769de 100644 --- a/dlt/common/storages/live_schema_storage.py +++ b/dlt/common/storages/live_schema_storage.py @@ -1,6 +1,6 @@ from typing import Dict -from dlt.common.configuration import SchemaVolumeConfiguration +from dlt.common.configuration.specs import SchemaVolumeConfiguration from dlt.common.schema.schema import Schema from dlt.common.storages.schema_storage import SchemaStorage diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index 06d0ecfacf..aa5132e826 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -7,7 +7,7 @@ from dlt.common.typing import DictStrAny, StrAny from dlt.common.file_storage import FileStorage from dlt.common.data_writers import TLoaderFileFormat, DataWriter -from dlt.common.configuration import LoadVolumeConfiguration +from dlt.common.configuration.specs import LoadVolumeConfiguration from dlt.common.exceptions import TerminalValueError from dlt.common.schema import Schema, TSchemaUpdate, TTableSchemaColumns from dlt.common.storages.versioned_storage import VersionedStorage diff --git a/dlt/common/storages/normalize_storage.py b/dlt/common/storages/normalize_storage.py index b0572f8423..9d800ca7d3 100644 --- a/dlt/common/storages/normalize_storage.py +++ b/dlt/common/storages/normalize_storage.py @@ -2,9 +2,8 @@ from itertools import groupby from pathlib import Path -from dlt.common.utils import chunks from dlt.common.file_storage import FileStorage -from dlt.common.configuration import NormalizeVolumeConfiguration +from dlt.common.configuration.specs import NormalizeVolumeConfiguration from dlt.common.storages.versioned_storage import VersionedStorage @@ -32,7 +31,7 @@ def list_files_to_normalize_sorted(self) -> Sequence[str]: return sorted(self.storage.list_folder_files(NormalizeStorage.EXTRACTED_FOLDER)) def get_grouped_iterator(self, files: Sequence[str]) -> "groupby[str, str]": - return groupby(files, lambda f: NormalizeStorage.get_schema_name(f)) + return groupby(files, NormalizeStorage.get_schema_name) @staticmethod def get_schema_name(file_name: str) -> str: diff --git a/dlt/common/storages/schema_storage.py b/dlt/common/storages/schema_storage.py index 6393e6ab52..1d9937ecf6 100644 --- a/dlt/common/storages/schema_storage.py +++ b/dlt/common/storages/schema_storage.py @@ -4,11 +4,10 @@ from typing import Iterator, List, Mapping from dlt.common import json, logger -from dlt.common.configuration.specs.schema_volume_configuration import 
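Among the hunks above, the reworked signals module keeps exposing a module-level exit_event that signal_receiver sets when a termination signal arrives, waking any thread sleeping on it. A sketch of the worker-loop pattern this enables (process_one_batch is a hypothetical unit of work):

    from dlt.common import signals

    def process_one_batch() -> None:
        ...   # hypothetical unit of work

    def run_until_signalled() -> None:
        # wake up as soon as signal_receiver sets the shared event instead of
        # sleeping through a fixed interval
        while not signals.exit_event.is_set():
            process_one_batch()
            signals.exit_event.wait(timeout=5.0)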
TSchemaFileFormat +from dlt.common.configuration.specs import SchemaVolumeConfiguration, TSchemaFileFormat from dlt.common.file_storage import FileStorage from dlt.common.schema import Schema, verify_schema_hash from dlt.common.typing import DictStrAny -from dlt.common.configuration import SchemaVolumeConfiguration from dlt.common.storages.exceptions import InStorageSchemaModified, SchemaNotFoundError diff --git a/dlt/common/storages/versioned_storage.py b/dlt/common/storages/versioned_storage.py index 9669e076e0..85bddb9588 100644 --- a/dlt/common/storages/versioned_storage.py +++ b/dlt/common/storages/versioned_storage.py @@ -31,7 +31,7 @@ def __init__(self, version: semver.VersionInfo, is_owner: bool, storage: FileSto if is_owner: self._save_version(version) else: - raise WrongStorageVersionException(storage.storage_path, semver.VersionInfo.parse("0.0.0"), version) + raise WrongStorageVersionException(storage.storage_path, semver.VersionInfo.parse("0.0.0"), version) def migrate_storage(self, from_version: semver.VersionInfo, to_version: semver.VersionInfo) -> None: # migration example: diff --git a/dlt/common/utils.py b/dlt/common/utils.py index 03ca293917..dbeed70f14 100644 --- a/dlt/common/utils.py +++ b/dlt/common/utils.py @@ -25,7 +25,11 @@ def uniq_id(len_: int = 16) -> str: def digest128(v: str) -> str: - return base64.b64encode(hashlib.shake_128(v.encode("utf-8")).digest(15)).decode('ascii') + return base64.b64encode( + hashlib.shake_128( + v.encode("utf-8") + ).digest(15) + ).decode('ascii') def digest256(v: str) -> str: @@ -155,16 +159,16 @@ def custom_environ(env: StrStr) -> Iterator[None]: def with_custom_environ(f: TFun) -> TFun: - @wraps(f) - def _wrap(*args: Any, **kwargs: Any) -> Any: - saved_environ = os.environ.copy() - try: - return f(*args, **kwargs) - finally: - os.environ.clear() - os.environ.update(saved_environ) + @wraps(f) + def _wrap(*args: Any, **kwargs: Any) -> Any: + saved_environ = os.environ.copy() + try: + return f(*args, **kwargs) + finally: + os.environ.clear() + os.environ.update(saved_environ) - return _wrap # type: ignore + return _wrap # type: ignore def encoding_for_mode(mode: str) -> Optional[str]: diff --git a/dlt/common/validation.py b/dlt/common/validation.py index 3c6cec9ad9..9fc1349c10 100644 --- a/dlt/common/validation.py +++ b/dlt/common/validation.py @@ -34,7 +34,7 @@ def verify_prop(pk: str, pv: Any, t: Any) -> None: if is_literal_type(t): a_l = get_args(t) if pv not in a_l: - raise DictValidationException(f"In {path}: field {pk} value {pv} not in allowed {a_l}", path, pk, pv) + raise DictValidationException(f"In {path}: field {pk} value {pv} not in allowed {a_l}", path, pk, pv) elif t in [int, bool, str, float]: if not isinstance(pv, t): raise DictValidationException(f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while {t.__name__} is expected", path, pk, pv) diff --git a/dlt/common/wei.py b/dlt/common/wei.py index 53babc23fc..218e5eee3a 100644 --- a/dlt/common/wei.py +++ b/dlt/common/wei.py @@ -1,8 +1,7 @@ from typing import Union -from dlt.common import Decimal from dlt.common.typing import TVariantRV, SupportsVariant -from dlt.common.arithmetics import default_context, decimal +from dlt.common.arithmetics import default_context, decimal, Decimal # default scale of EVM based blockchain WEI_SCALE = 18 diff --git a/dlt/dbt_runner/configuration.py b/dlt/dbt_runner/configuration.py index 947a5e5f15..3244535977 100644 --- a/dlt/dbt_runner/configuration.py +++ b/dlt/dbt_runner/configuration.py @@ -2,8 +2,9 @@ from 
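The re-indentation of with_custom_environ above is purely cosmetic; the decorator still snapshots os.environ before the wrapped call and restores it afterwards, which is what makes it safe for tests that mutate environment variables. A short usage sketch (assuming PIPELINE_NAME is not set beforehand):

    import os
    from dlt.common.utils import with_custom_environ

    @with_custom_environ
    def mutate_environment() -> None:
        os.environ["PIPELINE_NAME"] = "temporary name"   # visible only inside this call

    mutate_environment()
    assert "PIPELINE_NAME" not in os.environ             # the snapshot was restored on exit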
typing import List, Optional, Type from dlt.common.typing import StrAny, TSecretValue +from dlt.common.configuration import make_configuration, configspec from dlt.common.configuration.providers import environ -from dlt.common.configuration import PoolRunnerConfiguration, TPoolType, PostgresCredentials, GcpClientCredentials, make_configuration, configspec +from dlt.common.configuration.specs import PoolRunnerConfiguration, TPoolType, PostgresCredentials, GcpClientCredentials from . import __version__ diff --git a/dlt/dbt_runner/runner.py b/dlt/dbt_runner/runner.py index af4cf92825..afec511948 100644 --- a/dlt/dbt_runner/runner.py +++ b/dlt/dbt_runner/runner.py @@ -4,12 +4,12 @@ from prometheus_client.metrics import MetricWrapperBase from dlt.common import logger +from dlt.cli import TRunnerArgs from dlt.common.typing import DictStrAny, DictStrStr, StrAny from dlt.common.logger import is_json_logging from dlt.common.telemetry import get_logging_extras -from dlt.common.configuration import GcpClientCredentials +from dlt.common.configuration.specs import GcpClientCredentials from dlt.common.file_storage import FileStorage -from dlt.cli import TRunnerArgs from dlt.common.runners import initialize_runner, run_pool from dlt.common.telemetry import TRunMetrics @@ -32,30 +32,30 @@ def create_folders() -> Tuple[FileStorage, StrAny, Sequence[str], str, str]: - storage = FileStorage(CONFIG.package_volume_path, makedirs=True) - dbt_package_vars: DictStrAny = { + storage_ = FileStorage(CONFIG.package_volume_path, makedirs=True) + dbt_package_vars_: DictStrAny = { "source_schema_prefix": CONFIG.source_schema_prefix } if CONFIG.dest_schema_prefix: - dbt_package_vars["dest_schema_prefix"] = CONFIG.dest_schema_prefix + dbt_package_vars_["dest_schema_prefix"] = CONFIG.dest_schema_prefix if CONFIG.package_additional_vars: - dbt_package_vars.update(CONFIG.package_additional_vars) + dbt_package_vars_.update(CONFIG.package_additional_vars) # initialize dbt logging, returns global parameters to dbt command - global_args = initialize_dbt_logging(CONFIG.log_level, is_json_logging(CONFIG.log_format)) + global_args_ = initialize_dbt_logging(CONFIG.log_level, is_json_logging(CONFIG.log_format)) # generate path for the dbt package repo - repo_path = storage.make_full_path(CLONED_PACKAGE_NAME) + repo_path_ = storage_.make_full_path(CLONED_PACKAGE_NAME) # generate profile name - profile_name: str = None + profile_name_: str = None if CONFIG.package_profile_prefix: if isinstance(CONFIG, GcpClientCredentials): - profile_name = "%s_bigquery" % (CONFIG.package_profile_prefix) + profile_name_ = "%s_bigquery" % (CONFIG.package_profile_prefix) else: - profile_name = "%s_redshift" % (CONFIG.package_profile_prefix) + profile_name_ = "%s_redshift" % (CONFIG.package_profile_prefix) - return storage, dbt_package_vars, global_args, repo_path, profile_name + return storage_, dbt_package_vars_, global_args_, repo_path_, profile_name_ def create_gauges(registry: CollectorRegistry) -> Tuple[MetricWrapperBase, MetricWrapperBase]: diff --git a/dlt/helpers/streamlit.py b/dlt/helpers/streamlit.py index fc66fc3988..a65d0b2f0d 100644 --- a/dlt/helpers/streamlit.py +++ b/dlt/helpers/streamlit.py @@ -5,12 +5,13 @@ from typing import cast from copy import deepcopy +from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration +from dlt.common.utils import dict_remove_nones_in_place + from dlt.pipeline import Pipeline from dlt.pipeline.typing import credentials_from_dict from dlt.pipeline.exceptions import 
MissingDependencyException, PipelineException from dlt.helpers.pandas import query_results_to_df, pd -from dlt.common.configuration.base_configuration import BaseConfiguration, CredentialsConfiguration -from dlt.common.utils import dict_remove_nones_in_place try: import streamlit as st @@ -59,12 +60,12 @@ def backup_pipeline(pipeline: Pipeline) -> None: if os.path.isfile(SECRETS_FILE_LOC): with open(SECRETS_FILE_LOC, "r", encoding="utf-8") as f: # use whitespace preserving parser - secrets = tomlkit.load(f) + secrets_ = tomlkit.load(f) else: - secrets = tomlkit.document() + secrets_ = tomlkit.document() # save general settings - secrets["dlt"] = { + secrets_["dlt"] = { "working_dir": pipeline.working_dir, "pipeline_name": pipeline.pipeline_name } @@ -76,13 +77,13 @@ def backup_pipeline(pipeline: Pipeline) -> None: # save client config # print(dict_remove_nones_in_place(CONFIG.as_dict(lowercase=False))) - dlt_c = cast(TomlContainer, secrets["dlt"]) + dlt_c = cast(TomlContainer, secrets_["dlt"]) dlt_c["destination"] = dict_remove_nones_in_place(dict(CONFIG)) dlt_c["credentials"] = dict_remove_nones_in_place(dict(CREDENTIALS)) with open(SECRETS_FILE_LOC, "w", encoding="utf-8") as f: # use whitespace preserving parser - tomlkit.dump(secrets, f) + tomlkit.dump(secrets_, f) def write_data_explorer_page(pipeline: Pipeline, schema_name: str = None, show_dlt_tables: bool = False, example_query: str = "", show_charts: bool = True) -> None: diff --git a/dlt/load/bigquery/client.py b/dlt/load/bigquery/client.py index fae2104d81..8340706a90 100644 --- a/dlt/load/bigquery/client.py +++ b/dlt/load/bigquery/client.py @@ -12,7 +12,7 @@ from dlt.common.typing import StrAny from dlt.common.schema.typing import TTableSchema, TWriteDisposition from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.configuration import GcpClientCredentials +from dlt.common.configuration.specs import GcpClientCredentials from dlt.common.data_writers import escape_bigquery_identifier from dlt.common.schema import TColumnSchema, TDataType, Schema, TTableSchemaColumns diff --git a/dlt/load/bigquery/configuration.py b/dlt/load/bigquery/configuration.py index b4c8d248e7..cd314b1382 100644 --- a/dlt/load/bigquery/configuration.py +++ b/dlt/load/bigquery/configuration.py @@ -3,7 +3,8 @@ from google.auth.exceptions import DefaultCredentialsError from dlt.common.typing import StrAny -from dlt.common.configuration import make_configuration, GcpClientCredentials, configspec +from dlt.common.configuration import make_configuration, configspec +from dlt.common.configuration.specs import GcpClientCredentials from dlt.common.configuration.exceptions import ConfigEntryMissingException from dlt.load.configuration import LoaderClientDwhConfiguration diff --git a/dlt/load/client_base.py b/dlt/load/client_base.py index 1d07a30f7d..7ed63c540b 100644 --- a/dlt/load/client_base.py +++ b/dlt/load/client_base.py @@ -5,7 +5,7 @@ from pathlib import Path from dlt.common import pendulum, logger -from dlt.common.configuration import BaseConfiguration, CredentialsConfiguration +from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns from dlt.common.schema.typing import TTableSchema from dlt.common.typing import StrAny diff --git a/dlt/load/configuration.py b/dlt/load/configuration.py index f502a48379..797a82e99b 100644 --- a/dlt/load/configuration.py +++ b/dlt/load/configuration.py @@ -1,7 +1,8 @@ from typing 
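backup_pipeline above keeps using tomlkit rather than a plain TOML writer because, as the inline comment notes, it preserves the whitespace and comments of an existing secrets file when it is rewritten. A minimal illustration of that round-trip property (the document content is made up):

    import tomlkit

    doc = tomlkit.parse('# managed by dlt\n[dlt]\nworking_dir = "_storage"\n')
    doc["dlt"]["pipeline_name"] = "demo_pipeline"
    print(tomlkit.dumps(doc))   # the leading comment and original layout survive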
import Optional from dlt.common.typing import StrAny -from dlt.common.configuration import BaseConfiguration, PoolRunnerConfiguration, LoadVolumeConfiguration, TPoolType, make_configuration, configspec +from dlt.common.configuration import configspec, make_configuration +from dlt.common.configuration.specs import BaseConfiguration, PoolRunnerConfiguration, LoadVolumeConfiguration, TPoolType from . import __version__ diff --git a/dlt/load/dummy/client.py b/dlt/load/dummy/client.py index 6bb85c62cd..3b4888b77d 100644 --- a/dlt/load/dummy/client.py +++ b/dlt/load/dummy/client.py @@ -1,12 +1,11 @@ import random from types import TracebackType from typing import Dict, Tuple, Type -from dlt.common.data_writers import TLoaderFileFormat from dlt.common import pendulum from dlt.common.schema import Schema from dlt.common.schema.typing import TTableSchema -from dlt.common.configuration import CredentialsConfiguration +from dlt.common.configuration.specs import CredentialsConfiguration from dlt.common.typing import StrAny from dlt.load.client_base import JobClientBase, LoadJob, TLoaderCapabilities @@ -81,7 +80,7 @@ class DummyClient(JobClientBase): CONFIG: DummyClientConfiguration = None def __init__(self, schema: Schema) -> None: - super().__init__(schema) + pass def initialize_storage(self) -> None: pass diff --git a/dlt/load/exceptions.py b/dlt/load/exceptions.py index 6f7bb8a1d0..62a1cc67cb 100644 --- a/dlt/load/exceptions.py +++ b/dlt/load/exceptions.py @@ -11,12 +11,12 @@ def __init__(self, msg: str) -> None: class LoadClientTerminalException(LoadException, TerminalException): def __init__(self, msg: str) -> None: - super().__init__(msg) + pass class LoadClientTransientException(LoadException, TransientException): def __init__(self, msg: str) -> None: - super().__init__(msg) + pass class LoadClientTerminalInnerException(LoadClientTerminalException): diff --git a/dlt/load/redshift/client.py b/dlt/load/redshift/client.py index c35796e384..b4a65b65af 100644 --- a/dlt/load/redshift/client.py +++ b/dlt/load/redshift/client.py @@ -231,7 +231,7 @@ def start_file_load(self, table: TTableSchema, file_path: str) -> LoadJob: raise LoadClientTerminalInnerException("Terminal error, file will not load", tr_ex) if "Numeric data overflow" in tr_ex.pgerror: raise LoadClientTerminalInnerException("Terminal error, file will not load", tr_ex) - if "Precision exceeds maximum": + if "Precision exceeds maximum" in tr_ex.pgerror: raise LoadClientTerminalInnerException("Terminal error, file will not load", tr_ex) raise LoadClientTransientInnerException("Error may go away, will retry", tr_ex) except (psycopg2.DataError, psycopg2.ProgrammingError, psycopg2.IntegrityError) as ter_ex: diff --git a/dlt/load/redshift/configuration.py b/dlt/load/redshift/configuration.py index c17c39292e..4f92b1a5bc 100644 --- a/dlt/load/redshift/configuration.py +++ b/dlt/load/redshift/configuration.py @@ -1,7 +1,8 @@ from typing import Tuple from dlt.common.typing import StrAny -from dlt.common.configuration import configspec, make_configuration, PostgresCredentials +from dlt.common.configuration import configspec, make_configuration +from dlt.common.configuration.specs import PostgresCredentials from dlt.load.configuration import LoaderClientDwhConfiguration diff --git a/dlt/normalize/configuration.py b/dlt/normalize/configuration.py index 53359f3c91..045aaadc2a 100644 --- a/dlt/normalize/configuration.py +++ b/dlt/normalize/configuration.py @@ -1,8 +1,7 @@ from dlt.common.typing import StrAny from dlt.common.data_writers import 
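The Redshift change above is a behavioural fix, not a style one: a bare non-empty string literal is always truthy, so before the fix every error reaching that branch was classified as terminal and the transient fallback below it was unreachable. The corrected check only fires when the driver message actually contains the phrase:

    message = "ERROR: Precision exceeds maximum allowed"        # illustrative pgerror text
    assert bool("Precision exceeds maximum") is True            # old check: constant, always true
    assert ("Precision exceeds maximum" in message) is True     # new check: depends on the message
    assert ("Precision exceeds maximum" in "other error") is False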
TLoaderFileFormat -from dlt.common.configuration import (PoolRunnerConfiguration, NormalizeVolumeConfiguration, - LoadVolumeConfiguration, SchemaVolumeConfiguration, - TPoolType, make_configuration, configspec) +from dlt.common.configuration import make_configuration, configspec +from dlt.common.configuration.specs import PoolRunnerConfiguration, TPoolType, NormalizeVolumeConfiguration, LoadVolumeConfiguration, SchemaVolumeConfiguration from . import __version__ diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 2e55ba00b8..d0f1a23174 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -11,7 +11,8 @@ from dlt.common import json, sleep, signals, logger from dlt.common.runners import pool_runner as runner, TRunMetrics, initialize_runner -from dlt.common.configuration import PoolRunnerConfiguration, make_configuration +from dlt.common.configuration import make_configuration +from dlt.common.configuration.specs import PoolRunnerConfiguration from dlt.common.file_storage import FileStorage from dlt.common.schema import Schema from dlt.common.typing import DictStrAny, StrAny diff --git a/experiments/pipeline/configuration.py b/experiments/pipeline/configuration.py index 4bfddc0ac4..b5d9fc933e 100644 --- a/experiments/pipeline/configuration.py +++ b/experiments/pipeline/configuration.py @@ -8,7 +8,7 @@ from functools import wraps from dlt.common.typing import DictStrAny, StrAny, TAny, TFun -from dlt.common.configuration import BaseConfiguration +from dlt.common.configuration.specs import BaseConfiguration from dlt.common.configuration.resolve import NON_EVAL_TYPES, make_configuration, SIMPLE_TYPES # _POS_PARAMETER_KINDS = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD, Parameter.VAR_POSITIONAL) diff --git a/tests/common/runners/test_runners.py b/tests/common/runners/test_runners.py index 29dcf7027f..220e62dd58 100644 --- a/tests/common/runners/test_runners.py +++ b/tests/common/runners/test_runners.py @@ -5,8 +5,8 @@ from dlt.cli import TRunnerArgs from dlt.common import signals -from dlt.common.configuration import PoolRunnerConfiguration, make_configuration, configspec -from dlt.common.configuration.specs.pool_runner_configuration import TPoolType +from dlt.common.configuration import make_configuration, configspec +from dlt.common.configuration.specs import PoolRunnerConfiguration, TPoolType from dlt.common.exceptions import DltException, SignalReceivedException, TimeRangeExhaustedException, UnsupportedProcessStartMethodException from dlt.common.runners import pool_runner as runner diff --git a/tests/common/schema/test_schema.py b/tests/common/schema/test_schema.py index 5d58b93434..f8e39ae099 100644 --- a/tests/common/schema/test_schema.py +++ b/tests/common/schema/test_schema.py @@ -2,7 +2,8 @@ import pytest from dlt.common import pendulum -from dlt.common.configuration import SchemaVolumeConfiguration, make_configuration +from dlt.common.configuration import make_configuration +from dlt.common.configuration.specs import SchemaVolumeConfiguration from dlt.common.exceptions import DictValidationException from dlt.common.schema.typing import TColumnName, TSimpleRegex, COLUMN_HINTS from dlt.common.typing import DictStrAny, StrAny diff --git a/tests/common/storages/test_loader_storage.py b/tests/common/storages/test_loader_storage.py index d6143fa7e8..84a938d663 100644 --- a/tests/common/storages/test_loader_storage.py +++ b/tests/common/storages/test_loader_storage.py @@ -6,7 +6,8 @@ from dlt.common import sleep from dlt.common.schema import 
Schema from dlt.common.storages.load_storage import LoadStorage, TParsedJobFileName -from dlt.common.configuration import LoadVolumeConfiguration, make_configuration +from dlt.common.configuration import make_configuration +from dlt.common.configuration.specs import LoadVolumeConfiguration from dlt.common.storages.exceptions import NoMigrationPathException from dlt.common.typing import StrAny from dlt.common.utils import uniq_id diff --git a/tests/common/storages/test_normalize_storage.py b/tests/common/storages/test_normalize_storage.py index 88e0a28c3f..18c5d1e601 100644 --- a/tests/common/storages/test_normalize_storage.py +++ b/tests/common/storages/test_normalize_storage.py @@ -1,13 +1,14 @@ import pytest -from dlt.common.storages.exceptions import NoMigrationPathException +from dlt.common.utils import uniq_id from dlt.common.storages import NormalizeStorage -from dlt.common.configuration import NormalizeVolumeConfiguration +from dlt.common.storages.exceptions import NoMigrationPathException +from dlt.common.configuration.specs import NormalizeVolumeConfiguration from dlt.common.storages.normalize_storage import TParsedNormalizeFileName -from dlt.common.utils import uniq_id from tests.utils import write_version, autouse_test_storage + @pytest.mark.skip() def test_load_events_and_group_by_sender() -> None: # TODO: create fixture with two sender ids and 3 files and check the result diff --git a/tests/common/storages/test_schema_storage.py b/tests/common/storages/test_schema_storage.py index 547ac67a11..e73d07987d 100644 --- a/tests/common/storages/test_schema_storage.py +++ b/tests/common/storages/test_schema_storage.py @@ -4,15 +4,15 @@ import yaml from dlt.common import json -from dlt.common.configuration import make_configuration +from dlt.common.typing import DictStrAny from dlt.common.file_storage import FileStorage from dlt.common.schema.schema import Schema from dlt.common.schema.typing import TStoredSchema from dlt.common.schema.utils import default_normalizers -from dlt.common.configuration import SchemaVolumeConfiguration +from dlt.common.configuration import make_configuration +from dlt.common.configuration.specs import SchemaVolumeConfiguration from dlt.common.storages.exceptions import InStorageSchemaModified, SchemaNotFoundError from dlt.common.storages import SchemaStorage, LiveSchemaStorage -from dlt.common.typing import DictStrAny from tests.utils import autouse_test_storage, TEST_STORAGE_ROOT from tests.common.utils import load_yml_case, yml_case_path diff --git a/tests/common/test_configuration.py b/tests/common/test_configuration.py index c6b1b6c77a..197b8286f9 100644 --- a/tests/common/test_configuration.py +++ b/tests/common/test_configuration.py @@ -3,10 +3,8 @@ from typing import Any, Dict, List, Mapping, MutableMapping, NewType, Optional, Tuple, Type from dlt.common.typing import TSecretValue -from dlt.common.configuration import ( - RunConfiguration, ConfigEntryMissingException, ConfigFileNotFoundException, - ConfigEnvValueCannotBeCoercedException, BaseConfiguration, resolve, configspec) -from dlt.common.configuration.resolve import make_configuration +from dlt.common.configuration import configspec, ConfigEntryMissingException, ConfigFileNotFoundException, ConfigEnvValueCannotBeCoercedException, resolve +from dlt.common.configuration.specs import RunConfiguration, BaseConfiguration from dlt.common.configuration.providers import environ as environ_provider from dlt.common.utils import custom_environ @@ -191,47 +189,47 @@ def environment() -> Any: def 
test_initial_config_value() -> None: # set from init method - C = make_configuration(InstrumentedConfiguration(head="h", tube=["a", "b"], heels="he")) + C = resolve.make_configuration(InstrumentedConfiguration(head="h", tube=["a", "b"], heels="he")) assert C.to_native_representation() == "h>a>b>he" # set from native form - C = make_configuration(InstrumentedConfiguration(), initial_value="h>a>b>he") + C = resolve.make_configuration(InstrumentedConfiguration(), initial_value="h>a>b>he") assert C.head == "h" assert C.tube == ["a", "b"] assert C.heels == "he" # set from dictionary - C = make_configuration(InstrumentedConfiguration(), initial_value={"head": "h", "tube": ["tu", "be"], "heels": "xhe"}) + C = resolve.make_configuration(InstrumentedConfiguration(), initial_value={"head": "h", "tube": ["tu", "be"], "heels": "xhe"}) assert C.to_native_representation() == "h>tu>be>xhe" def test_check_integrity() -> None: with pytest.raises(RuntimeError): # head over hells - make_configuration(InstrumentedConfiguration(), initial_value="he>a>b>h") + resolve.make_configuration(InstrumentedConfiguration(), initial_value="he>a>b>h") def test_embedded_config(environment: Any) -> None: # resolve all embedded config, using initial value for instrumented config and initial dict for namespaced config - C = make_configuration(EmbeddedConfiguration(), initial_value={"default": "set", "instrumented": "h>tu>be>xhe", "namespaced": {"password": "pwd"}}) + C = resolve.make_configuration(EmbeddedConfiguration(), initial_value={"default": "set", "instrumented": "h>tu>be>xhe", "namespaced": {"password": "pwd"}}) assert C.default == "set" assert C.instrumented.to_native_representation() == "h>tu>be>xhe" assert C.namespaced.password == "pwd" # resolve but providing values via env with custom_environ({"INSTRUMENTED": "h>tu>u>be>xhe", "DLT_TEST__PASSWORD": "passwd", "DEFAULT": "DEF"}): - C = make_configuration(EmbeddedConfiguration()) + C = resolve.make_configuration(EmbeddedConfiguration()) assert C.default == "DEF" assert C.instrumented.to_native_representation() == "h>tu>u>be>xhe" assert C.namespaced.password == "passwd" # resolve partial, partial is passed to embedded - C = make_configuration(EmbeddedConfiguration(), accept_partial=True) + C = resolve.make_configuration(EmbeddedConfiguration(), accept_partial=True) assert C.__is_partial__ assert C.namespaced.__is_partial__ assert C.instrumented.__is_partial__ # some are partial, some are not with custom_environ({"DLT_TEST__PASSWORD": "passwd"}): - C = make_configuration(EmbeddedConfiguration(), accept_partial=True) + C = resolve.make_configuration(EmbeddedConfiguration(), accept_partial=True) assert C.__is_partial__ assert not C.namespaced.__is_partial__ assert C.instrumented.__is_partial__ @@ -239,23 +237,23 @@ def test_embedded_config(environment: Any) -> None: # single integrity error fails all the embeds with custom_environ({"INSTRUMENTED": "he>tu>u>be>h"}): with pytest.raises(RuntimeError): - make_configuration(EmbeddedConfiguration(), initial_value={"default": "set", "namespaced": {"password": "pwd"}}) + resolve.make_configuration(EmbeddedConfiguration(), initial_value={"default": "set", "namespaced": {"password": "pwd"}}) # part via env part via initial values with custom_environ({"INSTRUMENTED": "h>tu>u>be>he"}): - C = make_configuration(EmbeddedConfiguration(), initial_value={"default": "set", "namespaced": {"password": "pwd"}}) + C = resolve.make_configuration(EmbeddedConfiguration(), initial_value={"default": "set", "namespaced": {"password": "pwd"}}) assert 
C.instrumented.to_native_representation() == "h>tu>u>be>he" def test_provider_values_over_initial(environment: Any) -> None: with custom_environ({"INSTRUMENTED": "h>tu>u>be>he"}): - C = make_configuration(EmbeddedConfiguration(), initial_value={"instrumented": "h>tu>be>xhe"}, accept_partial=True) + C = resolve.make_configuration(EmbeddedConfiguration(), initial_value={"instrumented": "h>tu>be>xhe"}, accept_partial=True) assert C.instrumented.to_native_representation() == "h>tu>u>be>he" assert not C.instrumented.__is_partial__ def test_run_configuration_gen_name(environment: Any) -> None: - C = make_configuration(RunConfiguration()) + C = resolve.make_configuration(RunConfiguration()) assert C.pipeline_name.startswith("dlt_") @@ -275,7 +273,7 @@ def test_configuration_is_mutable_mapping(environment: Any) -> None: assert dict(SecretConfiguration()) == expected_dict environment["SECRET_VALUE"] = "secret" - C = make_configuration(SecretConfiguration()) + C = resolve.make_configuration(SecretConfiguration()) expected_dict["secret_value"] = "secret" assert dict(C) == expected_dict @@ -358,7 +356,7 @@ def test_raises_on_unresolved_fields() -> None: # via make configuration with pytest.raises(ConfigEntryMissingException) as config_entry_missing_exception: - make_configuration(WrongConfiguration()) + resolve.make_configuration(WrongConfiguration()) assert 'NONECONFIGVAR' in config_entry_missing_exception.value.missing_set @@ -367,9 +365,9 @@ def test_optional_types_are_not_required() -> None: keys = resolve._get_resolvable_fields(ConfigurationWithOptionalTypes()) resolve._is_config_bounded(ConfigurationWithOptionalTypes(), keys) # make optional config - make_configuration(ConfigurationWithOptionalTypes()) + resolve.make_configuration(ConfigurationWithOptionalTypes()) # make config with optional values - make_configuration(ProdConfigurationWithOptionalTypes(), initial_value={"INT_VAL": None}) + resolve.make_configuration(ProdConfigurationWithOptionalTypes(), initial_value={"INT_VAL": None}) def test_configuration_apply_adds_environment_variable_to_config(environment: Any) -> None: @@ -464,7 +462,7 @@ def test_initial_values(environment: Any) -> None: environment["PIPELINE_NAME"] = "env name" environment["CREATED_VAL"] = "12837" # set initial values and allow partial config - C = make_configuration(CoercionTestConfiguration(), + C = resolve.make_configuration(CoercionTestConfiguration(), {"pipeline_name": "initial name", "none_val": type(environment), "created_val": 878232, "bytes_val": b"str"}, accept_partial=True ) @@ -481,7 +479,7 @@ def test_accept_partial(environment: Any) -> None: # modify original type WrongConfiguration.NoneConfigVar = None # that None value will be present in the instance - C = make_configuration(WrongConfiguration(), accept_partial=True) + C = resolve.make_configuration(WrongConfiguration(), accept_partial=True) assert C.NoneConfigVar is None # partial resolution assert C.__is_partial__ diff --git a/tests/common/test_logging.py b/tests/common/test_logging.py index 8d0c95f3b4..7c6cf9ef11 100644 --- a/tests/common/test_logging.py +++ b/tests/common/test_logging.py @@ -6,7 +6,8 @@ from dlt import __version__ as auto_version from dlt.common import logger, sleep from dlt.common.typing import StrStr -from dlt.common.configuration import RunConfiguration, configspec +from dlt.common.configuration import configspec +from dlt.common.configuration.specs import RunConfiguration from tests.utils import preserve_environ @@ -31,6 +32,8 @@ class 
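test_provider_values_over_initial above pins down the precedence rule: values coming from providers, such as environment variables, override whatever was passed as initial_value. Restated as a minimal sketch (the spec and its field are illustrative):

    import os
    from dlt.common.configuration import configspec
    from dlt.common.configuration.resolve import make_configuration
    from dlt.common.configuration.specs import BaseConfiguration

    @configspec(init=True)
    class PrecedenceConfiguration(BaseConfiguration):
        value: str = None

    os.environ["VALUE"] = "from environment"
    config = make_configuration(PrecedenceConfiguration(), initial_value={"value": "from initial dict"})
    assert config.value == "from environment"   # the provider wins over the initial value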
SentryLoggerConfiguration(JsonLoggerConfiguration): sentry_dsn: str = "http://user:pass@localhost/818782" +import dataclasses +@dataclasses.dataclass @configspec(init=True) class SentryLoggerCriticalConfiguration(SentryLoggerConfiguration): log_level: str = "CRITICAL" diff --git a/tests/conftest.py b/tests/conftest.py index 75eb19d185..3e5c53ce52 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ def pytest_configure(config): # the dataclass implementation will use those patched values when creating instances (the values present # in the declaration are not frozen allowing patching) - from dlt.common.configuration import RunConfiguration, LoadVolumeConfiguration, NormalizeVolumeConfiguration, SchemaVolumeConfiguration + from dlt.common.configuration.specs import RunConfiguration, LoadVolumeConfiguration, NormalizeVolumeConfiguration, SchemaVolumeConfiguration test_storage_root = "_storage" RunConfiguration.config_files_storage_path = os.path.join(test_storage_root, "config/%s") diff --git a/tests/dbt_runner/test_runner_bigquery.py b/tests/dbt_runner/test_runner_bigquery.py index 8be838a012..bed8b2304b 100644 --- a/tests/dbt_runner/test_runner_bigquery.py +++ b/tests/dbt_runner/test_runner_bigquery.py @@ -2,10 +2,10 @@ import pytest from dlt.common import logger -from dlt.common.configuration import GcpClientCredentials +from dlt.common.configuration.specs import GcpClientCredentials from dlt.common.telemetry import TRunMetrics, get_metrics_from_prometheus from dlt.common.typing import StrStr -from dlt.common.utils import uniq_id, with_custom_environ +from dlt.common.utils import uniq_id from dlt.dbt_runner.utils import DBTProcessingError from dlt.dbt_runner import runner diff --git a/tests/dbt_runner/test_runner_redshift.py b/tests/dbt_runner/test_runner_redshift.py index 176cf2786a..455d893788 100644 --- a/tests/dbt_runner/test_runner_redshift.py +++ b/tests/dbt_runner/test_runner_redshift.py @@ -4,8 +4,8 @@ from prometheus_client import CollectorRegistry from dlt.common import logger -from dlt.common.configuration import PostgresCredentials -from dlt.common.configuration.resolve import make_configuration +from dlt.common.configuration import make_configuration +from dlt.common.configuration.specs import PostgresCredentials from dlt.common.file_storage import FileStorage from dlt.common.telemetry import TRunMetrics, get_metrics_from_prometheus from dlt.common.typing import StrStr diff --git a/tests/dbt_runner/utils.py b/tests/dbt_runner/utils.py index 4dabd56957..baf762ed28 100644 --- a/tests/dbt_runner/utils.py +++ b/tests/dbt_runner/utils.py @@ -21,7 +21,7 @@ def restore_secret_storage_path() -> None: def load_secret(name: str) -> str: environ.SECRET_STORAGE_PATH = "./tests/dbt_runner/secrets/%s" - secret = environ._get_key_value(name, environ.TSecretValue) + secret = environ.get_key(name, environ.TSecretValue) if not secret: raise FileNotFoundError(environ.SECRET_STORAGE_PATH % name) return secret diff --git a/tests/load/bigquery/test_bigquery_table_builder.py b/tests/load/bigquery/test_bigquery_table_builder.py index 34d44a1aa5..e3f677e4d4 100644 --- a/tests/load/bigquery/test_bigquery_table_builder.py +++ b/tests/load/bigquery/test_bigquery_table_builder.py @@ -4,7 +4,8 @@ from dlt.common.utils import custom_environ, uniq_id from dlt.common.schema import Schema from dlt.common.schema.utils import new_table -from dlt.common.configuration import make_configuration, GcpClientCredentials +from dlt.common.configuration import make_configuration +from 
dlt.common.configuration.specs import GcpClientCredentials from dlt.load.bigquery.client import BigQueryClient from dlt.load.exceptions import LoadClientSchemaWillNotUpdate diff --git a/tests/load/redshift/test_redshift_table_builder.py b/tests/load/redshift/test_redshift_table_builder.py index b273b1dbec..0c2fcfd77c 100644 --- a/tests/load/redshift/test_redshift_table_builder.py +++ b/tests/load/redshift/test_redshift_table_builder.py @@ -4,7 +4,8 @@ from dlt.common.utils import uniq_id, custom_environ from dlt.common.schema import Schema from dlt.common.schema.utils import new_table -from dlt.common.configuration import PostgresCredentials, make_configuration +from dlt.common.configuration import make_configuration +from dlt.common.configuration.specs import PostgresCredentials from dlt.load.exceptions import LoadClientSchemaWillNotUpdate from dlt.load.redshift.client import RedshiftClient diff --git a/tests/tools/create_storages.py b/tests/tools/create_storages.py index 680f3dd61f..e3e1f98865 100644 --- a/tests/tools/create_storages.py +++ b/tests/tools/create_storages.py @@ -1,5 +1,5 @@ from dlt.common.storages import NormalizeStorage, LoadStorage, SchemaStorage -from dlt.common.configuration import NormalizeVolumeConfiguration, LoadVolumeConfiguration, SchemaVolumeConfiguration +from dlt.common.configuration.specs import NormalizeVolumeConfiguration, LoadVolumeConfiguration, SchemaVolumeConfiguration NormalizeStorage(True, NormalizeVolumeConfiguration) diff --git a/tests/utils.py b/tests/utils.py index 57610f43bf..481f2eb350 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,7 +7,7 @@ from os import environ from dlt.common.configuration.resolve import _get_resolvable_fields, make_configuration -from dlt.common.configuration import RunConfiguration +from dlt.common.configuration.specs import RunConfiguration from dlt.common.logger import init_logging_from_config from dlt.common.file_storage import FileStorage from dlt.common.schema import Schema @@ -68,11 +68,11 @@ def clean_test_storage(init_normalize: bool = False, init_loader: bool = False) storage.create_folder(".") if init_normalize: from dlt.common.storages import NormalizeStorage - from dlt.common.configuration import NormalizeVolumeConfiguration + from dlt.common.configuration.specs import NormalizeVolumeConfiguration NormalizeStorage(True, NormalizeVolumeConfiguration) if init_loader: from dlt.common.storages import LoadStorage - from dlt.common.configuration import LoadVolumeConfiguration + from dlt.common.configuration.specs import LoadVolumeConfiguration LoadStorage(True, LoadVolumeConfiguration, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) return storage From c0cc1d630f11bb933f7ed6b63ba547247d000bc5 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 1 Oct 2022 18:34:18 +0200 Subject: [PATCH 23/66] moves file_storage to storages, writes lists of items into jsonl files between extract and normalize stages --- dlt/common/data_writers/writers.py | 7 +- dlt/common/normalizers/json/relational.py | 41 +++++----- dlt/common/schema/utils.py | 2 +- dlt/common/storages/__init__.py | 1 + dlt/common/{ => storages}/file_storage.py | 0 dlt/common/storages/load_storage.py | 2 +- dlt/common/storages/normalize_storage.py | 2 +- dlt/common/storages/schema_storage.py | 2 +- dlt/common/storages/versioned_storage.py | 2 +- dlt/dbt_runner/runner.py | 2 +- dlt/extract/extractor_storage.py | 2 +- dlt/normalize/normalize.py | 78 ++++++++++--------- dlt/pipeline/pipeline.py | 2 +- experiments/pipeline/extract.py | 2 +- 
experiments/pipeline/pipe.py | 2 +- experiments/pipeline/pipeline.py | 2 +- .../normalizers/test_json_relational.py | 13 ++++ tests/common/storages/test_file_storage.py | 2 +- tests/common/storages/test_schema_storage.py | 3 +- .../common/storages/test_versioned_storage.py | 2 +- tests/common/test_configuration.py | 4 +- tests/common/test_logging.py | 2 - tests/dbt_runner/test_runner_redshift.py | 2 +- tests/dbt_runner/test_utils.py | 2 +- tests/load/bigquery/test_bigquery_client.py | 2 +- tests/load/redshift/test_redshift_client.py | 2 +- tests/load/test_client.py | 2 +- tests/load/test_dummy_client.py | 4 +- tests/load/utils.py | 3 +- tests/utils.py | 2 +- 30 files changed, 107 insertions(+), 87 deletions(-) rename dlt/common/{ => storages}/file_storage.py (100%) diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index d2d788b64e..6656afac92 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -61,7 +61,7 @@ def class_factory(file_format: TLoaderFileFormat) -> Type["DataWriter"]: if file_format == "jsonl": return JsonlWriter elif file_format == "puae-jsonl": - return JsonlPUAEncodeWriter + return JsonlListPUAEncodeWriter elif file_format == "insert_values": return InsertValuesWriter else: @@ -87,14 +87,15 @@ def data_format(cls) -> TFileFormatSpec: return TFileFormatSpec("jsonl", "jsonl", False, True) -class JsonlPUAEncodeWriter(JsonlWriter): +class JsonlListPUAEncodeWriter(JsonlWriter): def write_data(self, rows: Sequence[Any]) -> None: # skip JsonlWriter when calling super super(JsonlWriter, self).write_data(rows) # encode types with PUA characters with jsonlines.Writer(self._f, dumps=json_typed_dumps) as w: - w.write_all(rows) + # write all rows as one list which will require to write just one line + w.write_all([rows]) @classmethod def data_format(cls) -> TFileFormatSpec: diff --git a/dlt/common/normalizers/json/relational.py b/dlt/common/normalizers/json/relational.py index 3fb67a9364..622457c20d 100644 --- a/dlt/common/normalizers/json/relational.py +++ b/dlt/common/normalizers/json/relational.py @@ -10,16 +10,16 @@ from dlt.common.validation import validate_dict -class TEventRow(TypedDict, total=False): +class TDataItemRow(TypedDict, total=False): _dlt_id: str # unique id of current row -class TEventRowRoot(TEventRow, total=False): +class TDataItemRowRoot(TDataItemRow, total=False): _dlt_load_id: str # load id to identify records loaded together that ie. 
need to be processed _dlt_meta: TEventDLTMeta # stores metadata, should never be sent to the normalizer -class TEventRowChild(TEventRow, total=False): +class TDataItemRowChild(TDataItemRow, total=False): _dlt_root_id: str # unique id of top level parent _dlt_parent_id: str # unique id of parent row _dlt_list_idx: int # position in the list of rows @@ -56,7 +56,7 @@ def _is_complex_type(schema: Schema, table_name: str, field_name: str, _r_lvl: i return data_type == "complex" -def _flatten(schema: Schema, table: str, dict_row: TEventRow, _r_lvl: int) -> Tuple[TEventRow, Dict[str, Sequence[Any]]]: +def _flatten(schema: Schema, table: str, dict_row: TDataItemRow, _r_lvl: int) -> Tuple[TDataItemRow, Dict[str, Sequence[Any]]]: out_rec_row: DictStrAny = {} out_rec_list: Dict[str, Sequence[Any]] = {} @@ -82,7 +82,7 @@ def norm_row_dicts(dict_row: StrAny, __r_lvl: int, parent_name: Optional[str]) - out_rec_row[child_name] = v norm_row_dicts(dict_row, _r_lvl, None) - return cast(TEventRow, out_rec_row), out_rec_list + return cast(TDataItemRow, out_rec_row), out_rec_list def _get_child_row_hash(parent_row_id: str, child_table: str, list_idx: int) -> str: @@ -91,18 +91,22 @@ def _get_child_row_hash(parent_row_id: str, child_table: str, list_idx: int) -> return digest128(f"{parent_row_id}_{child_table}_{list_idx}") -def _add_linking(row: TEventRowChild, extend: DictStrAny, parent_row_id: str, list_idx: int) -> TEventRowChild: +def _link_row(row: TDataItemRowChild, parent_row_id: str, list_idx: int) -> TDataItemRowChild: row["_dlt_parent_id"] = parent_row_id row["_dlt_list_idx"] = list_idx return row +def _extend_row(extend: DictStrAny, row: TDataItemRow) -> None: + row.update(extend) # type: ignore + + def _get_content_hash(schema: Schema, table: str, row: StrAny) -> str: return digest128(uniq_id()) -def _get_propagated_values(schema: Schema, table: str, row: TEventRow, is_top_level: bool) -> StrAny: +def _get_propagated_values(schema: Schema, table: str, row: TDataItemRow, is_top_level: bool) -> StrAny: config: JSONNormalizerConfigPropagation = (schema._normalizers_config["json"].get("config") or {}).get("propagation", None) extend: DictStrAny = {} if config: @@ -120,10 +124,6 @@ def _get_propagated_values(schema: Schema, table: str, row: TEventRow, is_top_le return extend -def _extend_row(extend: DictStrAny, row: TEventRow) -> None: - row.update(extend) # type: ignore - - # generate child tables only for lists def _normalize_list( schema: Schema, @@ -135,7 +135,7 @@ def _normalize_list( _r_lvl: int = 0 ) -> TNormalizedRowIterator: - v: TEventRowChild = None + v: TDataItemRowChild = None for idx, v in enumerate(seq): # yield child table row if isinstance(v, dict): @@ -149,14 +149,14 @@ def _normalize_list( child_row_hash = _get_child_row_hash(parent_row_id, table, idx) wrap_v = wrap_in_dict(v) wrap_v["_dlt_id"] = child_row_hash - e = _add_linking(wrap_v, extend, parent_row_id, idx) + e = _link_row(wrap_v, parent_row_id, idx) _extend_row(extend, e) yield (table, parent_table), e def _normalize_row( schema: Schema, - dict_row: TEventRow, + dict_row: TDataItemRow, extend: DictStrAny, table: str, parent_table: Optional[str] = None, @@ -177,12 +177,12 @@ def _normalize_row( primary_key = schema.filter_row_with_hint(table, "primary_key", flattened_row) if primary_key: # create row id from primary key - row_id = digest128("_".join(map(lambda v: str(v), primary_key.values()))) + row_id = digest128("_".join(map(str, primary_key.values()))) elif not is_top_level: # child table row deterministic hash row_id = 
_get_child_row_hash(parent_row_id, table, pos) # link to parent table - _add_linking(cast(TEventRowChild, flattened_row), extend, parent_row_id, pos) + _link_row(cast(TDataItemRowChild, flattened_row), parent_row_id, pos) else: # create hash based on the content of the row row_id = _get_content_hash(schema, table, flattened_row) @@ -222,8 +222,11 @@ def extend_schema(schema: Schema) -> None: def normalize_data_item(schema: Schema, item: TDataItem, load_id: str, table_name: str) -> TNormalizedRowIterator: + # wrap items that are not dictionaries in dictionary, otherwise they cannot be processed by the JSON normalizer + if not isinstance(item, dict): + item = wrap_in_dict(item) # we will extend event with all the fields necessary to load it as root row - event = cast(TEventRowRoot, item) + row = cast(TDataItemRowRoot, item) # identify load id if loaded data must be processed after loading incrementally - event["_dlt_load_id"] = load_id - yield from _normalize_row(schema, cast(TEventRowChild, event), {}, schema.normalize_table_name(table_name)) + row["_dlt_load_id"] = load_id + yield from _normalize_row(schema, cast(TDataItemRowChild, row), {}, schema.normalize_table_name(table_name)) diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index ec35c3bb31..040337e4ef 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -552,7 +552,7 @@ def new_table(table_name: str, parent_name: str = None, write_disposition: TWrit else: # set write disposition only for root tables table["write_disposition"] = write_disposition or DEFAULT_WRITE_DISPOSITION - print(f"new table {table_name} cid {id(table['columns'])}") + # print(f"new table {table_name} cid {id(table['columns'])}") return table diff --git a/dlt/common/storages/__init__.py b/dlt/common/storages/__init__.py index 9cae20f688..68d8c4aea4 100644 --- a/dlt/common/storages/__init__.py +++ b/dlt/common/storages/__init__.py @@ -1,3 +1,4 @@ +from .file_storage import FileStorage # noqa: F401 from .schema_storage import SchemaStorage # noqa: F401 from .live_schema_storage import LiveSchemaStorage # noqa: F401 from .normalize_storage import NormalizeStorage # noqa: F401 diff --git a/dlt/common/file_storage.py b/dlt/common/storages/file_storage.py similarity index 100% rename from dlt/common/file_storage.py rename to dlt/common/storages/file_storage.py diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index aa5132e826..13c5dac4b3 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -5,7 +5,7 @@ from dlt.common import json, pendulum from dlt.common.typing import DictStrAny, StrAny -from dlt.common.file_storage import FileStorage +from dlt.common.storages.file_storage import FileStorage from dlt.common.data_writers import TLoaderFileFormat, DataWriter from dlt.common.configuration.specs import LoadVolumeConfiguration from dlt.common.exceptions import TerminalValueError diff --git a/dlt/common/storages/normalize_storage.py b/dlt/common/storages/normalize_storage.py index 9d800ca7d3..446bec0063 100644 --- a/dlt/common/storages/normalize_storage.py +++ b/dlt/common/storages/normalize_storage.py @@ -2,7 +2,7 @@ from itertools import groupby from pathlib import Path -from dlt.common.file_storage import FileStorage +from dlt.common.storages.file_storage import FileStorage from dlt.common.configuration.specs import NormalizeVolumeConfiguration from dlt.common.storages.versioned_storage import VersionedStorage diff --git 
a/dlt/common/storages/schema_storage.py b/dlt/common/storages/schema_storage.py index 1d9937ecf6..ed1f0f0513 100644 --- a/dlt/common/storages/schema_storage.py +++ b/dlt/common/storages/schema_storage.py @@ -5,7 +5,7 @@ from dlt.common import json, logger from dlt.common.configuration.specs import SchemaVolumeConfiguration, TSchemaFileFormat -from dlt.common.file_storage import FileStorage +from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import Schema, verify_schema_hash from dlt.common.typing import DictStrAny diff --git a/dlt/common/storages/versioned_storage.py b/dlt/common/storages/versioned_storage.py index 85bddb9588..9dad05f9cc 100644 --- a/dlt/common/storages/versioned_storage.py +++ b/dlt/common/storages/versioned_storage.py @@ -1,6 +1,6 @@ import semver -from dlt.common.file_storage import FileStorage +from dlt.common.storages.file_storage import FileStorage from dlt.common.storages.exceptions import NoMigrationPathException, WrongStorageVersionException diff --git a/dlt/dbt_runner/runner.py b/dlt/dbt_runner/runner.py index afec511948..07c6e1edc5 100644 --- a/dlt/dbt_runner/runner.py +++ b/dlt/dbt_runner/runner.py @@ -9,7 +9,7 @@ from dlt.common.logger import is_json_logging from dlt.common.telemetry import get_logging_extras from dlt.common.configuration.specs import GcpClientCredentials -from dlt.common.file_storage import FileStorage +from dlt.common.storages import FileStorage from dlt.common.runners import initialize_runner, run_pool from dlt.common.telemetry import TRunMetrics diff --git a/dlt/extract/extractor_storage.py b/dlt/extract/extractor_storage.py index c116b2fb08..ce7d769c43 100644 --- a/dlt/extract/extractor_storage.py +++ b/dlt/extract/extractor_storage.py @@ -3,7 +3,7 @@ from dlt.common.json import json_typed_dumps from dlt.common.typing import Any from dlt.common.utils import uniq_id -from dlt.common.file_storage import FileStorage +from dlt.common.storages.file_storage import FileStorage from dlt.common.storages import VersionedStorage, NormalizeStorage diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 07163b2c87..89b3afd0e6 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -6,7 +6,6 @@ from dlt.common import pendulum, signals, json, logger from dlt.common.json import custom_pua_decode from dlt.cli import TRunnerArgs -from dlt.common.normalizers.json import wrap_in_dict from dlt.common.runners import TRunMetrics, Runnable, run_pool, initialize_runner from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.storages.exceptions import SchemaNotFoundError @@ -16,7 +15,6 @@ from dlt.common.exceptions import PoolException from dlt.common.schema import TSchemaUpdate, Schema from dlt.common.schema.exceptions import CannotCoerceColumnException -from dlt.common.utils import uniq_id from dlt.normalize.configuration import configuration, NormalizeConfiguration @@ -82,50 +80,24 @@ def w_normalize_files(CONFIG: NormalizeConfiguration, schema_name: str, load_id: normalize_storage = NormalizeStorage(False, CONFIG) schema_update: TSchemaUpdate = {} - column_schemas: Dict[str, TTableSchemaColumns] = {} # quick access to column schema for writers below total_items = 0 # process all files with data items and write to buffered item storage try: for extracted_items_file in extracted_items_files: line_no: int = 0 - item: TDataItem = None - parent_table_name = NormalizeStorage.parse_normalize_file_name(extracted_items_file).table_name - logger.debug(f"Processing extracted items in 
{extracted_items_file} in load_id {load_id} with table name {parent_table_name} and schema {schema_name}") + root_table_name = NormalizeStorage.parse_normalize_file_name(extracted_items_file).table_name + logger.debug(f"Processing extracted items in {extracted_items_file} in load_id {load_id} with table name {root_table_name} and schema {schema_name}") with normalize_storage.storage.open_file(extracted_items_file) as f: # enumerate jsonl file line by line for line_no, line in enumerate(f): - item = json.loads(line) - if not isinstance(item, dict): - item = wrap_in_dict(item) - for (table_name, parent_table), row in schema.normalize_data_item(schema, item, load_id, parent_table_name): - # filter row, may eliminate some or all fields - row = schema.filter_row(table_name, row) - # do not process empty rows - if row: - # decode pua types - for k, v in row.items(): - row[k] = custom_pua_decode(v) # type: ignore - # coerce row of values into schema table, generating partial table with new columns if any - row, partial_table = schema.coerce_row(table_name, parent_table, row) - if partial_table: - # update schema and save the change - schema.update_schema(partial_table) - table_updates = schema_update.setdefault(table_name, []) - table_updates.append(partial_table) - # get current columns schema - columns = column_schemas.get(table_name) - if not columns: - columns = schema.get_table_columns(table_name) - column_schemas[table_name] = columns - # store row - load_storage.write_data_item(load_id, schema_name, table_name, row, columns) - # count total items - total_items += 1 - if line_no > 0 and line_no % 100 == 0: - logger.debug(f"Processed {line_no} items from file {extracted_items_file}, total items {total_items}") + items: List[TDataItem] = json.loads(line) + partial_update, items_count = Normalize._w_normalize_chunk(load_storage, schema, load_id, root_table_name, items) + schema_update.update(partial_update) + total_items += items_count + logger.debug(f"Processed {line_no} items from file {extracted_items_file}, items {items_count} of total {total_items}") # if any item found in the file - if item: + if items_count > 0: logger.debug(f"Processed total {line_no + 1} lines from file {extracted_items_file}, total items {total_items}") except Exception: logger.exception(f"Exception when processing file {extracted_items_file}, line {line_no}") @@ -138,6 +110,40 @@ def w_normalize_files(CONFIG: NormalizeConfiguration, schema_name: str, load_id: return schema_update, total_items + @staticmethod + def _w_normalize_chunk(load_storage: LoadStorage, schema: Schema, load_id: str, root_table_name: str, items: List[TDataItem]) -> Tuple[TSchemaUpdate, int]: + column_schemas: Dict[str, TTableSchemaColumns] = {} # quick access to column schema for writers below + schema_update: TSchemaUpdate = {} + schema_name = schema.name + items_count = 0 + + for item in items: + for (table_name, parent_table), row in schema.normalize_data_item(schema, item, load_id, root_table_name): + # filter row, may eliminate some or all fields + row = schema.filter_row(table_name, row) + # do not process empty rows + if row: + # decode pua types + for k, v in row.items(): + row[k] = custom_pua_decode(v) # type: ignore + # coerce row of values into schema table, generating partial table with new columns if any + row, partial_table = schema.coerce_row(table_name, parent_table, row) + if partial_table: + # update schema and save the change + schema.update_schema(partial_table) + table_updates = schema_update.setdefault(table_name, []) + 
table_updates.append(partial_table) + # get current columns schema + columns = column_schemas.get(table_name) + if not columns: + columns = schema.get_table_columns(table_name) + column_schemas[table_name] = columns + # store row + load_storage.write_data_item(load_id, schema_name, table_name, row, columns) + # count total items + items_count += 1 + return schema_update, items_count + def map_parallel(self, schema_name: str, load_id: str, files: Sequence[str]) -> TMapFuncRV: # TODO: maybe we should chunk by file size, now map all files to workers chunk_files = [files] diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index d0f1a23174..eaf6f4116a 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -13,7 +13,7 @@ from dlt.common.runners import pool_runner as runner, TRunMetrics, initialize_runner from dlt.common.configuration import make_configuration from dlt.common.configuration.specs import PoolRunnerConfiguration -from dlt.common.file_storage import FileStorage +from dlt.common.storages import FileStorage from dlt.common.schema import Schema from dlt.common.typing import DictStrAny, StrAny from dlt.common.utils import uniq_id, is_interactive diff --git a/experiments/pipeline/extract.py b/experiments/pipeline/extract.py index f7b4a6e12d..8acf0e014f 100644 --- a/experiments/pipeline/extract.py +++ b/experiments/pipeline/extract.py @@ -5,7 +5,7 @@ from dlt.common.sources import TDirectDataItem, TDataItem from dlt.common.schema import utils, TSchemaUpdate from dlt.common.storages import NormalizeStorage, DataItemStorage -from dlt.common.configuration import NormalizeVolumeConfiguration +from dlt.common.configuration.specs import NormalizeVolumeConfiguration from experiments.pipeline.pipe import PipeIterator diff --git a/experiments/pipeline/pipe.py b/experiments/pipeline/pipe.py index 810b542363..5b874ce570 100644 --- a/experiments/pipeline/pipe.py +++ b/experiments/pipeline/pipe.py @@ -7,6 +7,7 @@ from typing import Optional, Sequence, Union, Callable, Iterable, Iterator, List, NamedTuple, Awaitable, Tuple, Type, TYPE_CHECKING from dlt.common.typing import TDataItem +from dlt.common.sources import TDirectDataItem, TResolvableDataItem if TYPE_CHECKING: TItemFuture = Future[TDirectDataItem] @@ -15,7 +16,6 @@ from dlt.common.exceptions import DltException from dlt.common.time import sleep -from dlt.common.sources import TDirectDataItem, TResolvableDataItem class PipeItem(NamedTuple): diff --git a/experiments/pipeline/pipeline.py b/experiments/pipeline/pipeline.py index e47cecf986..1513929e1b 100644 --- a/experiments/pipeline/pipeline.py +++ b/experiments/pipeline/pipeline.py @@ -18,7 +18,7 @@ from dlt.common.configuration import make_configuration, RunConfiguration, NormalizeVolumeConfiguration, SchemaVolumeConfiguration, ProductionNormalizeVolumeConfiguration from dlt.common.schema.schema import Schema -from dlt.common.file_storage import FileStorage +from dlt.common.storages.file_storage import FileStorage from dlt.common.utils import is_interactive, uniq_id from dlt.extract.extractor_storage import ExtractorStorageBase diff --git a/tests/common/normalizers/test_json_relational.py b/tests/common/normalizers/test_json_relational.py index 47a383b778..a9dbe5a2ba 100644 --- a/tests/common/normalizers/test_json_relational.py +++ b/tests/common/normalizers/test_json_relational.py @@ -507,6 +507,19 @@ def test_preserves_complex_types_list(schema: Schema) -> None: assert root_row[1]["value"] == row["value"] +def test_wrap_in_dict(schema: Schema) -> None: + # json 
normalizer wraps in dict + row = list(schema.normalize_data_item(schema, 1, "load_id", "simplex"))[0][1] + assert row["value"] == 1 + assert row["_dlt_load_id"] == "load_id" + # wrap a list + rows = list(schema.normalize_data_item(schema, [1, 2, 3, 4, "A"], "load_id", "listex")) + assert len(rows) == 6 + assert rows[0][0] == ("listex", None,) + assert rows[1][0] == ("listex__value", "listex") + assert rows[-1][1]["value"] == "A" + + def test_complex_types_for_recursion_level(schema: Schema) -> None: add_dlt_root_id_propagation(schema) # if max recursion depth is set, nested elements will be kept as complex diff --git a/tests/common/storages/test_file_storage.py b/tests/common/storages/test_file_storage.py index 46e2bcc653..b978670b2a 100644 --- a/tests/common/storages/test_file_storage.py +++ b/tests/common/storages/test_file_storage.py @@ -1,7 +1,7 @@ import os import pytest -from dlt.common.file_storage import FileStorage +from dlt.common.storages.file_storage import FileStorage from dlt.common.utils import encoding_for_mode, uniq_id from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage, test_storage diff --git a/tests/common/storages/test_schema_storage.py b/tests/common/storages/test_schema_storage.py index e73d07987d..31ead5021d 100644 --- a/tests/common/storages/test_schema_storage.py +++ b/tests/common/storages/test_schema_storage.py @@ -5,14 +5,13 @@ from dlt.common import json from dlt.common.typing import DictStrAny -from dlt.common.file_storage import FileStorage from dlt.common.schema.schema import Schema from dlt.common.schema.typing import TStoredSchema from dlt.common.schema.utils import default_normalizers from dlt.common.configuration import make_configuration from dlt.common.configuration.specs import SchemaVolumeConfiguration from dlt.common.storages.exceptions import InStorageSchemaModified, SchemaNotFoundError -from dlt.common.storages import SchemaStorage, LiveSchemaStorage +from dlt.common.storages import SchemaStorage, LiveSchemaStorage, FileStorage from tests.utils import autouse_test_storage, TEST_STORAGE_ROOT from tests.common.utils import load_yml_case, yml_case_path diff --git a/tests/common/storages/test_versioned_storage.py b/tests/common/storages/test_versioned_storage.py index c7c2236cc5..ff23480a48 100644 --- a/tests/common/storages/test_versioned_storage.py +++ b/tests/common/storages/test_versioned_storage.py @@ -1,7 +1,7 @@ import pytest import semver -from dlt.common.file_storage import FileStorage +from dlt.common.storages.file_storage import FileStorage from dlt.common.storages.exceptions import NoMigrationPathException, WrongStorageVersionException from dlt.common.storages.versioned_storage import VersionedStorage diff --git a/tests/common/test_configuration.py b/tests/common/test_configuration.py index 197b8286f9..47a8354ed8 100644 --- a/tests/common/test_configuration.py +++ b/tests/common/test_configuration.py @@ -304,10 +304,10 @@ def test_configuration_is_mutable_mapping(environment: Any) -> None: assert C._version != "1.1.1" # delete is not supported - with pytest.raises(NotImplementedError): + with pytest.raises(KeyError): del C["pipeline_name"] - with pytest.raises(NotImplementedError): + with pytest.raises(KeyError): C.pop("pipeline_name", None) # setting supported diff --git a/tests/common/test_logging.py b/tests/common/test_logging.py index 7c6cf9ef11..8e3a4bcd42 100644 --- a/tests/common/test_logging.py +++ b/tests/common/test_logging.py @@ -32,8 +32,6 @@ class SentryLoggerConfiguration(JsonLoggerConfiguration): sentry_dsn: str 
= "http://user:pass@localhost/818782" -import dataclasses -@dataclasses.dataclass @configspec(init=True) class SentryLoggerCriticalConfiguration(SentryLoggerConfiguration): log_level: str = "CRITICAL" diff --git a/tests/dbt_runner/test_runner_redshift.py b/tests/dbt_runner/test_runner_redshift.py index 455d893788..d7e4d42b15 100644 --- a/tests/dbt_runner/test_runner_redshift.py +++ b/tests/dbt_runner/test_runner_redshift.py @@ -6,7 +6,7 @@ from dlt.common import logger from dlt.common.configuration import make_configuration from dlt.common.configuration.specs import PostgresCredentials -from dlt.common.file_storage import FileStorage +from dlt.common.storages import FileStorage from dlt.common.telemetry import TRunMetrics, get_metrics_from_prometheus from dlt.common.typing import StrStr from dlt.common.utils import uniq_id, with_custom_environ diff --git a/tests/dbt_runner/test_utils.py b/tests/dbt_runner/test_utils.py index 2ea3191930..162d5fd20a 100644 --- a/tests/dbt_runner/test_utils.py +++ b/tests/dbt_runner/test_utils.py @@ -3,7 +3,7 @@ from git import GitCommandError, Repo, RepositoryDirtyError import pytest -from dlt.common.file_storage import FileStorage +from dlt.common.storages import FileStorage from dlt.dbt_runner.utils import DBTProcessingError, clone_repo, ensure_remote_head, git_custom_key_command, initialize_dbt_logging, run_dbt_command diff --git a/tests/load/bigquery/test_bigquery_client.py b/tests/load/bigquery/test_bigquery_client.py index 8a6a3adfe4..e4ef6dcf70 100644 --- a/tests/load/bigquery/test_bigquery_client.py +++ b/tests/load/bigquery/test_bigquery_client.py @@ -4,7 +4,7 @@ from dlt.common import json, pendulum, Decimal from dlt.common.arithmetics import numeric_default_context -from dlt.common.file_storage import FileStorage +from dlt.common.storages import FileStorage from dlt.common.schema.schema import Schema from dlt.common.utils import uniq_id from dlt.load.exceptions import LoadJobNotExistsException, LoadJobServerTerminalException diff --git a/tests/load/redshift/test_redshift_client.py b/tests/load/redshift/test_redshift_client.py index 020bb42a52..a5353b3253 100644 --- a/tests/load/redshift/test_redshift_client.py +++ b/tests/load/redshift/test_redshift_client.py @@ -4,7 +4,7 @@ from dlt.common import pendulum, Decimal from dlt.common.arithmetics import numeric_default_context -from dlt.common.file_storage import FileStorage +from dlt.common.storages import FileStorage from dlt.common.schema.schema import Schema from dlt.common.utils import uniq_id diff --git a/tests/load/test_client.py b/tests/load/test_client.py index 359d44a24d..d3d66d60eb 100644 --- a/tests/load/test_client.py +++ b/tests/load/test_client.py @@ -6,7 +6,7 @@ from dlt.common import json, pendulum from dlt.common.schema import Schema from dlt.common.schema.utils import new_table -from dlt.common.file_storage import FileStorage +from dlt.common.storages import FileStorage from dlt.common.schema import TTableSchemaColumns from dlt.common.utils import uniq_id diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index 404e5b46bf..7a7979c5bf 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -7,10 +7,10 @@ from unittest.mock import patch from prometheus_client import CollectorRegistry -from dlt.common.file_storage import FileStorage from dlt.common.exceptions import TerminalException, TerminalValueError from dlt.common.schema import Schema -from dlt.common.storages.load_storage import JobWithUnsupportedWriterException, 
LoadStorage +from dlt.common.storages import FileStorage, LoadStorage +from dlt.common.storages.load_storage import JobWithUnsupportedWriterException from dlt.common.typing import StrAny from dlt.common.utils import uniq_id from dlt.load.client_base import JobClientBase, LoadEmptyJob, LoadJob diff --git a/tests/load/utils.py b/tests/load/utils.py index 367ab55e34..be067595ed 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -6,9 +6,8 @@ from dlt.common.configuration import make_configuration from dlt.common.configuration.specs.schema_volume_configuration import SchemaVolumeConfiguration from dlt.common.data_writers import DataWriter -from dlt.common.file_storage import FileStorage from dlt.common.schema import TColumnSchema, TTableSchemaColumns -from dlt.common.storages.schema_storage import SchemaStorage +from dlt.common.storages import SchemaStorage, FileStorage from dlt.common.schema.utils import new_table from dlt.common.time import sleep from dlt.common.typing import StrAny diff --git a/tests/utils.py b/tests/utils.py index 481f2eb350..c3a769a4c7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,7 +9,7 @@ from dlt.common.configuration.resolve import _get_resolvable_fields, make_configuration from dlt.common.configuration.specs import RunConfiguration from dlt.common.logger import init_logging_from_config -from dlt.common.file_storage import FileStorage +from dlt.common.storages import FileStorage from dlt.common.schema import Schema from dlt.common.storages.versioned_storage import VersionedStorage from dlt.common.typing import StrAny From 80233c038b143588558e2b18b75ac2499759055e Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 1 Oct 2022 18:38:06 +0200 Subject: [PATCH 24/66] data item may be of any type, not dict loaded from json --- dlt/common/typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlt/common/typing.py b/dlt/common/typing.py index 69a0ea941f..856a1dc976 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -18,7 +18,7 @@ TFun = TypeVar("TFun", bound=Callable[..., Any]) TAny = TypeVar("TAny", bound=Any) TSecretValue = NewType("TSecretValue", str) # represent secret value ie. 
coming from Kubernetes/Docker secrets or other providers -TDataItem = DictStrAny +TDataItem = Any # a single data item extracted from data source, normalized and loaded TVariantBase = TypeVar("TVariantBase", covariant=True) From f7b5434244490030c6cef4c94c5eda743d680dde Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 5 Oct 2022 22:35:20 +0200 Subject: [PATCH 25/66] adds resolving configs via providers with variable namespaces, provides lookup traces --- dlt/common/configuration/__init__.py | 2 +- dlt/common/configuration/exceptions.py | 57 ++-- dlt/common/configuration/inject.py | 25 ++ .../configuration/providers/__init__.py | 2 + .../configuration/providers/configuration.py | 17 ++ dlt/common/configuration/providers/environ.py | 79 +++--- .../configuration/providers/provider.py | 22 ++ dlt/common/configuration/resolve.py | 220 +++++++++------ .../configuration/specs/base_configuration.py | 51 +++- .../specs/gcp_client_credentials.py | 10 +- dlt/common/logger.py | 12 +- dlt/common/schema/utils.py | 8 +- dlt/common/typing.py | 43 ++- dlt/dbt_runner/utils.py | 2 +- dlt/load/bigquery/client.py | 2 +- experiments/pipeline/configuration.py | 4 +- experiments/pipeline/pipe.py | 6 +- tests/cases.py | 6 +- tests/common/schema/test_coercion.py | 16 +- tests/common/test_configuration.py | 254 ++++++++++++------ tests/common/test_typing.py | 29 +- tests/utils.py | 23 +- 22 files changed, 630 insertions(+), 260 deletions(-) create mode 100644 dlt/common/configuration/inject.py create mode 100644 dlt/common/configuration/providers/configuration.py create mode 100644 dlt/common/configuration/providers/provider.py diff --git a/dlt/common/configuration/__init__.py b/dlt/common/configuration/__init__.py index 30f9a4a0d2..0459774597 100644 --- a/dlt/common/configuration/__init__.py +++ b/dlt/common/configuration/__init__.py @@ -1,4 +1,4 @@ -from .specs.base_configuration import configspec # noqa: F401 +from .specs.base_configuration import configspec, is_valid_hint # noqa: F401 from .resolve import make_configuration # noqa: F401 from .exceptions import ( # noqa: F401 diff --git a/dlt/common/configuration/exceptions.py b/dlt/common/configuration/exceptions.py index d15de37a33..ce13c949bb 100644 --- a/dlt/common/configuration/exceptions.py +++ b/dlt/common/configuration/exceptions.py @@ -1,8 +1,15 @@ -from typing import Any, Iterable, Type, Union +from typing import Any, Iterable, Mapping, Type, Union, NamedTuple, Sequence from dlt.common.exceptions import DltException +class LookupTrace(NamedTuple): + provider: str + namespaces: Sequence[str] + key: str + value: Any + + class ConfigurationException(DltException): def __init__(self, msg: str) -> None: super().__init__(msg) @@ -16,24 +23,26 @@ def __init__(self, _typ: type) -> None: class ConfigEntryMissingException(ConfigurationException): """thrown when not all required config elements are present""" - def __init__(self, missing_set: Iterable[str], namespace: str = None) -> None: - self.missing_set = missing_set - self.namespace = namespace + def __init__(self, spec_name: str, traces: Mapping[str, Sequence[LookupTrace]]) -> None: + self.traces = traces + self.spec_name = spec_name - msg = 'Missing config keys: ' + str(missing_set) - if namespace: - msg += ". 
Note that required namespace for that keys is " + namespace + " and namespace separator is two underscores" + msg = f"Following fields are missing: {str(list(traces.keys()))} in configuration with spec {spec_name}\n" + for f, traces in traces.items(): + msg += f'\tfor field "{f}" config providers and keys were tried in following order\n' + for tr in traces: + msg += f'\t\tIn {tr.provider} key {tr.key} was not found.\n' super().__init__(msg) class ConfigEnvValueCannotBeCoercedException(ConfigurationException): """thrown when value from ENV cannot be coerced to hinted type""" - def __init__(self, attr_name: str, env_value: str, hint: type) -> None: - self.attr_name = attr_name - self.env_value = env_value + def __init__(self, field_name: str, field_value: Any, hint: type) -> None: + self.field_name = field_name + self.field_value = field_value self.hint = hint - super().__init__('env value %s cannot be coerced into type %s in attr %s' % (env_value, str(hint), attr_name)) + super().__init__('env value %s cannot be coerced into type %s in attr %s' % (field_value, str(hint), field_name)) class ConfigIntegrityException(ConfigurationException): @@ -53,10 +62,26 @@ def __init__(self, path: str) -> None: super().__init__(f"Missing config file in {path}") -class ConfigFieldMissingAnnotationException(ConfigurationException): - """thrown when configuration specification does not have type annotation""" +class ConfigFieldMissingTypeHintException(ConfigurationException): + """thrown when configuration specification does not have type hint""" - def __init__(self, field_name: str, typ_: Type[Any]) -> None: + def __init__(self, field_name: str, spec: Type[Any]) -> None: self.field_name = field_name - self.typ_ = typ_ - super().__init__(f"Field {field_name} on configspec {typ_} does not provide required type annotation") + self.typ_ = spec + super().__init__(f"Field {field_name} on configspec {spec} does not provide required type hint") + + +class ConfigFieldTypeHintNotSupported(ConfigurationException): + """thrown when configuration specification uses not supported type in hint""" + + def __init__(self, field_name: str, spec: Type[Any], typ_: Type[Any]) -> None: + self.field_name = field_name + self.typ_ = spec + super().__init__(f"Field {field_name} on configspec {spec} has hint with unsupported type {typ_}") + + +class ValueNotSecretException(ConfigurationException): + def __init__(self, provider_name: str, key: str) -> None: + self.provider_name = provider_name + self.key = key + super().__init__(f"Provider {provider_name} cannot hold secret values but key {key} with secret value is present") diff --git a/dlt/common/configuration/inject.py b/dlt/common/configuration/inject.py new file mode 100644 index 0000000000..55f0b888e6 --- /dev/null +++ b/dlt/common/configuration/inject.py @@ -0,0 +1,25 @@ +from typing import Dict, Type, TypeVar + +from dlt.common.configuration.specs.base_configuration import BaseConfiguration + +TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration) + + +class Container: + + _INSTANCE: "Container" = None + + def __new__(cls: Type["Container"]) -> "Container": + if not cls._INSTANCE: + cls._INSTANCE = super().__new__(cls) + return cls._INSTANCE + + + def __init__(self) -> None: + self.configurations: Dict[Type[BaseConfiguration], BaseConfiguration] = {} + + + def __getitem__(self, spec: Type[TConfiguration]) -> TConfiguration: + # return existing config object or create it from spec + return self.configurations.setdefault(spec, spec()) + diff --git 
a/dlt/common/configuration/providers/__init__.py b/dlt/common/configuration/providers/__init__.py index e69de29bb2..10d7b9b24a 100644 --- a/dlt/common/configuration/providers/__init__.py +++ b/dlt/common/configuration/providers/__init__.py @@ -0,0 +1,2 @@ +from .provider import Provider +from .environ import EnvironProvider \ No newline at end of file diff --git a/dlt/common/configuration/providers/configuration.py b/dlt/common/configuration/providers/configuration.py new file mode 100644 index 0000000000..cdfbaaac1f --- /dev/null +++ b/dlt/common/configuration/providers/configuration.py @@ -0,0 +1,17 @@ + + +from typing import List + +from dlt.common.configuration.providers import Provider +from dlt.common.configuration.providers.environ import EnvironProvider +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec + + +@configspec +class ConfigProvidersConfiguration(BaseConfiguration): + providers: List[Provider] + + def __init__(self) -> None: + super().__init__() + # add default providers + self.providers = [EnvironProvider()] diff --git a/dlt/common/configuration/providers/environ.py b/dlt/common/configuration/providers/environ.py index c5364f5973..02278cb057 100644 --- a/dlt/common/configuration/providers/environ.py +++ b/dlt/common/configuration/providers/environ.py @@ -1,41 +1,54 @@ from os import environ from os.path import isdir -from typing import Any, Optional, Type +from typing import Any, Optional, Type, Tuple from dlt.common.typing import TSecretValue +from .provider import Provider + SECRET_STORAGE_PATH: str = "/run/secrets/%s" +class EnvironProvider(Provider): + + @staticmethod + def get_key_name(key: str, *namespaces: str) -> str: + # env key is always upper case + if namespaces: + namespaces = filter(lambda x: bool(x), namespaces) + env_key = "__".join((*namespaces, key)) + else: + env_key = key + return env_key.upper() + + @property + def name(self) -> str: + return "Environment Variables" + + def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Optional[Any], str]: + # apply namespace to the key + key = self.get_key_name(key, *namespaces) + if hint is TSecretValue: + # try secret storage + try: + # must conform to RFC1123 + secret_name = key.lower().replace("_", "-") + secret_path = SECRET_STORAGE_PATH % secret_name + # kubernetes stores secrets as files in a dir, docker compose plainly + if isdir(secret_path): + secret_path += "/" + secret_name + with open(secret_path, "r", encoding="utf-8") as f: + secret = f.read() + # add secret to environ so forks have access + # TODO: removing new lines is not always good. 
for password OK for PEMs not + # TODO: in regular secrets that is dealt with in particular configuration logic + environ[key] = secret.strip() + # do not strip returned secret + return secret, key + # includes FileNotFound + except OSError: + pass + return environ.get(key, None), key -def get_key_name(key: str, namespace: str = None) -> str: - # env key is always upper case - if namespace: - env_key = namespace + "__" + key - else: - env_key = key - return env_key.upper() - -def get_key(key: str, hint: Type[Any], namespace: str = None) -> Optional[str]: - # apply namespace to the key - key = get_key_name(key, namespace) - if hint is TSecretValue: - # try secret storage - try: - # must conform to RFC1123 - secret_name = key.lower().replace("_", "-") - secret_path = SECRET_STORAGE_PATH % secret_name - # kubernetes stores secrets as files in a dir, docker compose plainly - if isdir(secret_path): - secret_path += "/" + secret_name - with open(secret_path, "r", encoding="utf-8") as f: - secret = f.read() - # add secret to environ so forks have access - # TODO: removing new lines is not always good. for password OK for PEMs not - # TODO: in regular secrets that is dealt with in particular configuration logic - environ[key] = secret.strip() - # do not strip returned secret - return secret - # includes FileNotFound - except OSError: - pass - return environ.get(key, None) \ No newline at end of file + @property + def is_secret(self) -> bool: + return True diff --git a/dlt/common/configuration/providers/provider.py b/dlt/common/configuration/providers/provider.py new file mode 100644 index 0000000000..5257221560 --- /dev/null +++ b/dlt/common/configuration/providers/provider.py @@ -0,0 +1,22 @@ +import abc +from typing import Any, Tuple, Type, Optional + + + +class Provider(abc.ABC): + def __init__(self) -> None: + pass + + @abc.abstractmethod + def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Optional[Any], str]: + pass + + @property + @abc.abstractmethod + def is_secret(self) -> bool: + pass + + @property + @abc.abstractmethod + def name(self) -> str: + pass diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index 36f8f7290b..510d7f4e24 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -1,21 +1,21 @@ +import ast import dataclasses import inspect import sys import semver -from typing import Any, Dict, List, Mapping, Type, TypeVar - -from dlt.common.typing import is_optional_type, is_literal_type -from dlt.common.configuration.specs.base_configuration import BaseConfiguration -from dlt.common.configuration.providers import environ -from dlt.common.configuration.exceptions import (ConfigEntryMissingException, ConfigurationWrongTypeException, ConfigEnvValueCannotBeCoercedException) - -SIMPLE_TYPES: List[Any] = [int, bool, list, dict, tuple, bytes, set, float] -# those types and Optionals of those types should not be passed to eval function -NON_EVAL_TYPES = [str, None, Any] -# allows to coerce (type1 from type2) -ALLOWED_TYPE_COERCIONS = [(float, int), (str, int), (str, float)] -CHECK_INTEGRITY_F: str = "check_integrity" +from collections.abc import Mapping as C_Mapping +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, get_origin + +from dlt.common import json, logger +from dlt.common.typing import TSecretValue, is_optional_type, extract_inner_type +from dlt.common.schema.utils import coerce_type, py_type_to_sc_type +from 
dlt.common.configuration.specs.base_configuration import BaseConfiguration, CredentialsConfiguration, configspec +from dlt.common.configuration.inject import Container +from dlt.common.configuration.providers.configuration import ConfigProvidersConfiguration +from dlt.common.configuration.exceptions import (LookupTrace, ConfigEntryMissingException, ConfigurationWrongTypeException, ConfigEnvValueCannotBeCoercedException, ValueNotSecretException) + +CHECK_INTEGRITY_F: str = "check_integrity" TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration) @@ -23,23 +23,23 @@ def make_configuration(config: TConfiguration, initial_value: Any = None, accept if not isinstance(config, BaseConfiguration): raise ConfigurationWrongTypeException(type(config)) - # get fields to resolve as per dataclasses PEP - fields = _get_resolvable_fields(config) # parse initial value if possible if initial_value is not None: try: config.from_native_representation(initial_value) except (NotImplementedError, ValueError): # if parsing failed and initial_values is dict then apply - if isinstance(initial_value, Mapping): + # TODO: we may try to parse with json here if str + if isinstance(initial_value, C_Mapping): config.update(initial_value) + else: + raise InvalidInitialValue(type(config), type(initial_value)) - _resolve_config_fields(config, fields, accept_partial) try: - _is_config_bounded(config, fields) + _resolve_config_fields(config, accept_partial) _check_configuration_integrity(config) # full configuration was resolved - config.__is_partial__ = False + config.__is_resolved__ = True except ConfigEntryMissingException: if not accept_partial: raise @@ -48,6 +48,45 @@ def make_configuration(config: TConfiguration, initial_value: Any = None, accept return config +def deserialize_value(key: str, value: Any, hint: Type[Any]) -> Any: + try: + if hint != Any: + hint_dt = py_type_to_sc_type(hint) + value_dt = py_type_to_sc_type(type(value)) + + # eval only if value is string and hint is "complex" + if value_dt == "text" and hint_dt == "complex": + if hint is tuple: + # use literal eval for tuples + value = ast.literal_eval(value) + else: + # use json for sequences and mappings + value = json.loads(value) + # exact types must match + if not isinstance(value, hint): + raise ValueError(value) + else: + # for types that are not complex, reuse schema coercion rules + if value_dt != hint_dt: + value = coerce_type(hint_dt, value_dt, value) + return value + except ConfigEnvValueCannotBeCoercedException: + raise + except Exception as exc: + raise ConfigEnvValueCannotBeCoercedException(key, value, hint) from exc + + +def serialize_value(value: Any) -> Any: + if value is None: + raise ValueError(value) + # return literal for tuples + if isinstance(value, tuple): + return str(value) + # coerce type to text which will use json for mapping and sequences + value_dt = py_type_to_sc_type(type(value)) + return coerce_type("text", value_dt, value) + + def _add_module_version(config: BaseConfiguration) -> None: try: v = sys._getframe(1).f_back.f_globals["__version__"] @@ -57,61 +96,59 @@ def _add_module_version(config: BaseConfiguration) -> None: pass -def _resolve_config_fields(config: BaseConfiguration, fields: Mapping[str, type], accept_partial: bool) -> None: +def _resolve_config_fields(config: BaseConfiguration, accept_partial: bool) -> None: + fields = config.get_resolvable_fields() + unresolved_fields: Dict[str, Sequence[LookupTrace]] = {} + for key, hint in fields.items(): # get default value - resolved_value = getattr(config, 
key, None) - # resolve key value via active providers - value = environ.get_key(key, hint, config.__namespace__) - - # extract hint from Optional / NewType hints - hint = _extract_simple_type(hint) - # if hint is BaseConfiguration then resolve it recursively - if inspect.isclass(hint) and issubclass(hint, BaseConfiguration): - if isinstance(resolved_value, BaseConfiguration): - # if actual value is BaseConfiguration, resolve that instance - resolved_value = make_configuration(resolved_value, accept_partial=accept_partial) - else: - # create new instance and pass value from the provider as initial - resolved_value = make_configuration(hint(), initial_value=value or resolved_value, accept_partial=accept_partial) + current_value = getattr(config, key, None) + # check if hint optional + is_optional = is_optional_type(hint) + # accept partial becomes True if type if optional so we do not fail on optional configs that do not resolve fully + accept_partial = accept_partial or is_optional + # if actual value is BaseConfiguration, resolve that instance + if isinstance(current_value, BaseConfiguration): + current_value = make_configuration(current_value, accept_partial=accept_partial) else: - if value is not None: - resolved_value = _coerce_single_value(key, value, hint) - # set value resolved value - setattr(config, key, resolved_value) - - -def _coerce_single_value(key: str, value: str, hint: Type[Any]) -> Any: - try: - if hint not in NON_EVAL_TYPES: - # create primitive types out of strings - typed_value = eval(value) # nosec - # for primitive types check coercion - if hint in SIMPLE_TYPES and type(typed_value) != hint: - # allow some exceptions - coerce_exception = next( - (e for e in ALLOWED_TYPE_COERCIONS if e == (hint, type(typed_value))), None) - if coerce_exception: - return hint(typed_value) - else: - raise ConfigEnvValueCannotBeCoercedException(key, typed_value, hint) - return typed_value - else: - return value - except ConfigEnvValueCannotBeCoercedException: - raise - except Exception as exc: - raise ConfigEnvValueCannotBeCoercedException(key, value, hint) from exc + # resolve key value via active providers + value, traces = _resolve_single_field(key, hint, config.__namespace__) + + # log trace + if logger.is_logging() and logger.log_level() == "DEBUG": + logger.debug(f"Field {key} with type {hint} in {type(config).__name__} {'NOT RESOLVED' if value is None else 'RESOLVED'}") + for tr in traces: + logger.debug(str(tr)) + + # extract hint from Optional / Literal / NewType hints + hint = extract_inner_type(hint) + # extract origin from generic types + hint = get_origin(hint) or hint + # if hint is BaseConfiguration then resolve it recursively + if inspect.isclass(hint) and issubclass(hint, BaseConfiguration): + # create new instance and pass value from the provider as initial + current_value = make_configuration(hint(), initial_value=value or current_value, accept_partial=accept_partial) + else: + if value is not None: + current_value = deserialize_value(key, value, hint) + # collect unresolved fields + if not is_optional and current_value is None: + unresolved_fields[key] = traces + # set resolved value in config + setattr(config, key, current_value) + if unresolved_fields: + raise ConfigEntryMissingException(type(config).__name__, unresolved_fields) -def _is_config_bounded(config: BaseConfiguration, fields: Mapping[str, type]) -> None: - # TODO: here we assume all keys are taken from environ provider, that should change when we introduce more providers - _unbound_attrs = [ - 
environ.get_key_name(key, config.__namespace__) for key in fields if getattr(config, key) is None and not is_optional_type(fields[key]) - ] +# def _is_config_bounded(config: BaseConfiguration, fields: Mapping[str, type]) -> None: +# # TODO: here we assume all keys are taken from environ provider, that should change when we introduce more providers +# # environ.get_key_name(key, config.__namespace__) +# _unbound_attrs = [ +# key for key in fields if getattr(config, key) is None and not is_optional_type(fields[key]) +# ] - if len(_unbound_attrs) > 0: - raise ConfigEntryMissingException(_unbound_attrs, config.__namespace__) +# if len(_unbound_attrs) > 0: +# raise ConfigEntryMissingException(_unbound_attrs, config.__namespace__) def _check_configuration_integrity(config: BaseConfiguration) -> None: @@ -132,15 +169,38 @@ def _get_resolvable_fields(config: BaseConfiguration) -> Dict[str, type]: return {f.name:f.type for f in dataclasses.fields(config) if not f.name.startswith("__")} -def _extract_simple_type(hint: Type[Any]) -> Type[Any]: - # extract optional type and call recursively - if is_literal_type(hint): - # assume that all literals are of the same type - return _extract_simple_type(type(hint.__args__[0])) - if is_optional_type(hint): - # todo: use `get_args` in python 3.8 - return _extract_simple_type(hint.__args__[0]) - if not hasattr(hint, "__supertype__"): - return hint - # descend into supertypes of NewType - return _extract_simple_type(hint.__supertype__) +@configspec +class ConfigNamespacesConfiguration(BaseConfiguration): + pipeline_name: Optional[str] + namespaces: List[str] + + def __init__(self) -> None: + super().__init__() + self.namespaces = [] + + +def _resolve_single_field(key: str, hint: Type[Any], namespace: str, *namespaces: str) -> Tuple[Optional[Any], List[LookupTrace]]: + # get providers from container + providers = Container()[ConfigProvidersConfiguration].providers + # get additional namespaces to look in from container + context_namespaces = Container()[ConfigNamespacesConfiguration].namespaces + + # start looking from the top provider with most specific set of namespaces first + traces: List[LookupTrace] = [] + value = None + ns = [*namespaces, *context_namespaces] + for provider in providers: + while True: + # first namespace always present + _ns_t = (namespace, *ns) if namespace else ns + value, ns_key = provider.get_value(key, hint, *_ns_t) + # create trace + traces.append(LookupTrace(provider.name, _ns_t, ns_key, value)) + # if secret is obtained from non secret provider, we must fail + if value is not None and not provider.is_secret and (hint is TSecretValue or (inspect.isclass(hint) and issubclass(hint, CredentialsConfiguration))): + raise ValueNotSecretException(provider.name, ns_key) + if len(ns) == 0 or value is not None: + break + ns.pop() + + return value, traces diff --git a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py index 75ccf9a600..57ed3c78c5 100644 --- a/dlt/common/configuration/specs/base_configuration.py +++ b/dlt/common/configuration/specs/base_configuration.py @@ -1,13 +1,28 @@ +import contextlib import dataclasses -from typing import Any, Dict, Iterator, MutableMapping, Type, TYPE_CHECKING +from typing import Any, Dict, Iterator, MutableMapping, Type, TYPE_CHECKING, get_origin if TYPE_CHECKING: TDtcField = dataclasses.Field[Any] else: TDtcField = dataclasses.Field -from dlt.common.typing import TAny -from dlt.common.configuration.exceptions import 
ConfigFieldMissingAnnotationException +from dlt.common.typing import TAny, extract_inner_type, is_optional_type +from dlt.common.schema.utils import py_type_to_sc_type +from dlt.common.configuration.exceptions import ConfigFieldMissingTypeHintException, ConfigFieldTypeHintNotSupported + + +def is_valid_hint(hint: Type[Any]) -> bool: + hint = extract_inner_type(hint) + hint = get_origin(hint) or hint + if hint is Any: + return True + if issubclass(hint, BaseConfiguration): + return True + with contextlib.suppress(TypeError): + py_type_to_sc_type(hint) + return True + return False def configspec(cls: Type[TAny] = None, /, *, init: bool = False) -> Type[TAny]: @@ -19,10 +34,12 @@ def wrap(cls: Type[TAny]) -> Type[TAny]: setattr(cls, ann, None) # get all attributes without corresponding annotations for att_name, att in cls.__dict__.items(): - if not callable(att) and not att_name.startswith(("__", "_abc_impl")) and att_name not in cls.__annotations__: - print(att) - print(callable(att)) - raise ConfigFieldMissingAnnotationException(att_name, cls) + if not callable(att) and not att_name.startswith(("__", "_abc_impl")): + if att_name not in cls.__annotations__: + raise ConfigFieldMissingTypeHintException(att_name, cls) + hint = cls.__annotations__[att_name] + if not is_valid_hint(hint): + raise ConfigFieldTypeHintNotSupported(att_name, cls, hint) return dataclasses.dataclass(cls, init=init, eq=False) # type: ignore # called with parenthesis @@ -35,8 +52,8 @@ def wrap(cls: Type[TAny]) -> Type[TAny]: @configspec class BaseConfiguration(MutableMapping[str, Any]): - # will be set to true if not all config entries could be resolved - __is_partial__: bool = dataclasses.field(default = True, init=False, repr=False) + # true when all config fields were resolved and have a specified value type + __is_resolved__: bool = dataclasses.field(default = False, init=False, repr=False) # namespace used by config providers when searching for keys __namespace__: str = dataclasses.field(default = None, init=False, repr=False) @@ -67,6 +84,22 @@ def to_native_representation(self) -> Any: """ raise ValueError() + def get_resolvable_fields(self) -> Dict[str, type]: + """Returns a mapping of fields to their type hints. 
Dunder should not be resolved and are not returned""" + return {f.name:f.type for f in self.__fields_dict().values() if not f.name.startswith("__")} + + def is_resolved(self) -> bool: + return self.__is_resolved__ + + def is_partial(self) -> bool: + """Returns True when any required resolvable field has its value missing.""" + if self.__is_resolved__: + return False + # check if all resolvable fields have value + return any( + field for field, hint in self.get_resolvable_fields().items() if getattr(self, field) is None and not is_optional_type(hint) + ) + # implement dictionary-compatible interface on top of dataclass def __getitem__(self, __key: str) -> Any: diff --git a/dlt/common/configuration/specs/gcp_client_credentials.py b/dlt/common/configuration/specs/gcp_client_credentials.py index d51921d7a4..7d9ffd7695 100644 --- a/dlt/common/configuration/specs/gcp_client_credentials.py +++ b/dlt/common/configuration/specs/gcp_client_credentials.py @@ -21,14 +21,14 @@ class GcpClientCredentials(CredentialsConfiguration): file_upload_timeout: float = 30 * 60.0 retry_deadline: float = 600 # how long to retry the operation in case of error, the backoff 60s - def from_native_representation(self, initial_value: Any) -> None: - if not isinstance(initial_value, str): - raise ValueError(initial_value) + def from_native_representation(self, native_value: Any) -> None: + if not isinstance(native_value, str): + raise ValueError(native_value) try: - service_dict = json.loads(initial_value) + service_dict = json.loads(native_value) self.update(service_dict) except Exception: - raise ValueError(initial_value) + raise ValueError(native_value) def check_integrity(self) -> None: if self.private_key and self.private_key[-1] != "\n": diff --git a/dlt/common/logger.py b/dlt/common/logger.py index 53924c1e41..2d55636f43 100644 --- a/dlt/common/logger.py +++ b/dlt/common/logger.py @@ -14,7 +14,7 @@ from dlt._version import common_version as __version__ -DLT_LOGGER_NAME = "sv-dlt" +DLT_LOGGER_NAME = "dlt" LOGGER: Logger = None @@ -211,6 +211,16 @@ def init_logging_from_config(C: RunConfiguration) -> None: _init_sentry(C, version) +def is_logging() -> bool: + return LOGGER is not None + + +def log_level() -> str: + if not LOGGER: + raise RuntimeError("Logger not initialized") + return logging.getLevelName(LOGGER.level) + + def is_json_logging(log_format: str) -> bool: return log_format == "JSON" diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index 040337e4ef..c722134689 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -5,6 +5,7 @@ import datetime # noqa: I251 import contextlib from copy import deepcopy +from collections.abc import Mapping as C_Mapping, Sequence as C_Sequence from typing import Dict, List, Sequence, Tuple, Type, Any, cast from dlt.common import pendulum, json, Decimal, Wei @@ -287,7 +288,8 @@ def py_type_to_sc_type(t: Type[Any]) -> TDataType: return "wei" if issubclass(t, Decimal): return "decimal" - if issubclass(t, datetime.datetime): + # TODO: implement new "date" type, currently assign "datetime" + if issubclass(t, (datetime.datetime, datetime.date)): return "timestamp" # check again for subclassed basic types @@ -299,8 +301,10 @@ def py_type_to_sc_type(t: Type[Any]) -> TDataType: return "bigint" if issubclass(t, bytes): return "binary" + if issubclass(t, (C_Mapping, C_Sequence)): + return "complex" - return "text" + raise TypeError(t) def coerce_type(to_type: TDataType, from_type: TDataType, value: Any) -> Any: diff --git a/dlt/common/typing.py 
b/dlt/common/typing.py index 856a1dc976..aef5b9f245 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -37,10 +37,7 @@ def __call__(self) -> Union[TVariantBase, TVariantRV]: def is_optional_type(t: Type[Any]) -> bool: - # todo: use typing get_args and get_origin in python 3.8 - if hasattr(t, "__origin__"): - return t.__origin__ is Union and type(None) in t.__args__ - return False + return get_origin(t) is Union and type(None) in get_args(t) def extract_optional_type(t: Type[Any]) -> Any: @@ -48,7 +45,11 @@ def extract_optional_type(t: Type[Any]) -> Any: def is_literal_type(hint: Type[Any]) -> bool: - return hasattr(hint, "__origin__") and hint.__origin__ is Literal + return get_origin(hint) is Literal + + +def is_newtype_type(t: Type[Any]) -> bool: + return hasattr(t, "__supertype__") def is_typeddict(t: Any) -> bool: @@ -57,15 +58,35 @@ def is_typeddict(t: Any) -> bool: def is_list_generic_type(t: Any) -> bool: try: - o = get_origin(t) - return issubclass(o, list) or issubclass(o, C_Sequence) - except Exception: + return issubclass(get_origin(t), C_Sequence) + except TypeError: return False def is_dict_generic_type(t: Any) -> bool: try: - o = get_origin(t) - return issubclass(o, dict) or issubclass(o, C_Mapping) - except Exception: + return issubclass(get_origin(t), C_Mapping) + except TypeError: return False + + +def extract_inner_type(hint: Type[Any]) -> Type[Any]: + """Gets the inner type from Literal, Optional and NewType + + Args: + hint (Type[Any]): Any type + + Returns: + Type[Any]: Inner type if hint was Literal, Optional or NewType, otherwise hint + """ + if is_literal_type(hint): + # assume that all literals are of the same type + return extract_inner_type(type(get_args(hint)[0])) + if is_optional_type(hint): + # extract optional type and call recursively + return extract_inner_type(get_args(hint)[0]) + if is_newtype_type(hint): + # descend into supertypes of NewType + return extract_inner_type(hint.__supertype__) + return hint + diff --git a/dlt/dbt_runner/utils.py b/dlt/dbt_runner/utils.py index 2d426bd727..51a07d336a 100644 --- a/dlt/dbt_runner/utils.py +++ b/dlt/dbt_runner/utils.py @@ -50,7 +50,7 @@ def git_custom_key_command(private_key: Optional[str]) -> Iterator[str]: def ensure_remote_head(repo_path: str, with_git_command: Optional[str] = None) -> None: # update remotes and check if heads are same. 
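For reference, the typing helpers changed above compose as follows; a minimal sketch based on the docstring of extract_inner_type and the tests added later in this patch series, with OrderId as a purely illustrative NewType:

    from typing import Literal, NewType, Optional
    from dlt.common.typing import extract_inner_type, is_optional_type

    OrderId = NewType("OrderId", int)  # hypothetical alias, only for illustration

    is_optional_type(Optional[str])        # True: a Union that includes None
    extract_inner_type(Optional[OrderId])  # int: Optional and NewType are unwrapped recursively
    extract_inner_type(Literal["a", "b"])  # str: the type of the first literal value is taken
    extract_inner_type(str)                # str: plain types are returned unchanged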
ignores locally modified files repo = Repo(repo_path) - # use custom environemnt if specified + # use custom environment if specified with repo.git.custom_environment(GIT_SSH_COMMAND=with_git_command): # update origin repo.remote().update() diff --git a/dlt/load/bigquery/client.py b/dlt/load/bigquery/client.py index 8340706a90..08a33be155 100644 --- a/dlt/load/bigquery/client.py +++ b/dlt/load/bigquery/client.py @@ -62,7 +62,7 @@ def __init__(self, default_dataset_name: str, CREDENTIALS: GcpClientCredentials) def open_connection(self) -> None: # use default credentials if partial config - if self.C.__is_partial__: + if not self.C.is_resolved(): credentials = None else: credentials = service_account.Credentials.from_service_account_info(self.C.to_native_representation()) diff --git a/experiments/pipeline/configuration.py b/experiments/pipeline/configuration.py index b5d9fc933e..56697f4538 100644 --- a/experiments/pipeline/configuration.py +++ b/experiments/pipeline/configuration.py @@ -8,8 +8,8 @@ from functools import wraps from dlt.common.typing import DictStrAny, StrAny, TAny, TFun +from dlt.common.configuration import make_configuration, is_valid_hint from dlt.common.configuration.specs import BaseConfiguration -from dlt.common.configuration.resolve import NON_EVAL_TYPES, make_configuration, SIMPLE_TYPES # _POS_PARAMETER_KINDS = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD, Parameter.VAR_POSITIONAL) @@ -44,7 +44,7 @@ def spec_from_signature(name: str, sig: Signature) -> Type[BaseConfiguration]: # skip *args and **kwargs if p.kind not in (Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL) and p.name not in ["self", "cls"]: field_type = Any if p.annotation == Parameter.empty else p.annotation - if field_type in SIMPLE_TYPES or field_type in NON_EVAL_TYPES or issubclass(field_type, BaseConfiguration): + if is_valid_hint(field_type): field_default = None if p.default == Parameter.empty else dataclasses.field(default=p.default) if field_default: # correct the type if Any diff --git a/experiments/pipeline/pipe.py b/experiments/pipeline/pipe.py index 5b874ce570..9b0a43dac0 100644 --- a/experiments/pipeline/pipe.py +++ b/experiments/pipeline/pipe.py @@ -190,8 +190,8 @@ def __repr__(self) -> str: class PipeIterator(Iterator[PipeItem]): - def __init__(self, max_parallelism: int = 100, worker_threads: int = 5, futures_poll_interval: float = 0.01) -> None: - self.max_parallelism = max_parallelism + def __init__(self, max_parallel_items: int = 100, worker_threads: int = 5, futures_poll_interval: float = 0.01) -> None: + self.max_parallel_items = max_parallel_items self.worker_threads = worker_threads self.futures_poll_interval = futures_poll_interval @@ -275,7 +275,7 @@ def __next__(self) -> PipeItem: if isinstance(pipe_item.item, Awaitable) or callable(pipe_item.item): # do we have a free slot or one of the slots is done? 
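The is_resolved() check in the BigQuery client above relies on the new resolution state API on BaseConfiguration. A minimal sketch of that behavior, assuming a hypothetical ConnectionConfiguration spec that is not part of this change:

    from typing import Optional
    from dlt.common.configuration import configspec
    from dlt.common.configuration.specs import BaseConfiguration

    @configspec
    class ConnectionConfiguration(BaseConfiguration):  # hypothetical spec, for illustration only
        host: str = None             # required: not Optional, so a missing value keeps the spec partial
        port: Optional[int] = None   # optional: may stay None

    c = ConnectionConfiguration()
    c.get_resolvable_fields()  # {"host": str, "port": Optional[int]} - dunder fields are skipped
    c.is_resolved()            # False until make_configuration() marks the spec resolved
    c.is_partial()             # True - the required "host" is still None
    c.host = "localhost"
    c.is_partial()             # False - all non-optional fields now carry values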
- if len(self._futures) < self.max_parallelism or self._next_future() >= 0: + if len(self._futures) < self.max_parallel_items or self._next_future() >= 0: if isinstance(pipe_item.item, Awaitable): future = asyncio.run_coroutine_threadsafe(pipe_item.item, self._ensure_async_pool()) else: diff --git a/tests/cases.py b/tests/cases.py index b17f666d0c..8da604528b 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -15,7 +15,7 @@ "big_decimal": Decimal("115792089237316195423570985008687907853269984665640564039457584007913129639935.1"), "datetime": pendulum.parse("2005-04-02T20:37:37.358236Z"), "date": pendulum.parse("2022-02-02").date(), - "uuid": UUID(_UUID), + # "uuid": UUID(_UUID), "hexbytes": HexBytes("0x2137"), "bytes": b'2137', "wei": Wei.from_int256(2137, decimals=2) @@ -26,8 +26,8 @@ "decimal": "decimal", "big_decimal": "decimal", "datetime": "timestamp", - "date": "text", - "uuid": "text", + "date": "timestamp", + # "uuid": "text", "hexbytes": "binary", "bytes": "binary", "wei": "wei" diff --git a/tests/common/schema/test_coercion.py b/tests/common/schema/test_coercion.py index 68170b2ee8..660ee22ce2 100644 --- a/tests/common/schema/test_coercion.py +++ b/tests/common/schema/test_coercion.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping, MutableSequence from typing import Any, Type import pytest import datetime # noqa: I251 @@ -198,13 +199,24 @@ def test_py_type_to_sc_type() -> None: assert utils.py_type_to_sc_type(int) == "bigint" assert utils.py_type_to_sc_type(float) == "double" assert utils.py_type_to_sc_type(str) == "text" - # unknown types are recognized as text - assert utils.py_type_to_sc_type(Exception) == "text" assert utils.py_type_to_sc_type(type(pendulum.now())) == "timestamp" assert utils.py_type_to_sc_type(type(datetime.datetime(1988, 12, 1))) == "timestamp" assert utils.py_type_to_sc_type(type(Decimal(1))) == "decimal" assert utils.py_type_to_sc_type(type(HexBytes("0xFF"))) == "binary" assert utils.py_type_to_sc_type(type(Wei.from_int256(2137, decimals=2))) == "wei" + # unknown types raise TypeException + with pytest.raises(TypeError): + utils.py_type_to_sc_type(Any) + # none type raises TypeException + with pytest.raises(TypeError): + utils.py_type_to_sc_type(type(None)) + # complex types + assert utils.py_type_to_sc_type(list) == "complex" + # assert utils.py_type_to_sc_type(set) == "complex" + assert utils.py_type_to_sc_type(dict) == "complex" + assert utils.py_type_to_sc_type(tuple) == "complex" + assert utils.py_type_to_sc_type(Mapping) == "complex" + assert utils.py_type_to_sc_type(MutableSequence) == "complex" def test_coerce_type_complex() -> None: diff --git a/tests/common/test_configuration.py b/tests/common/test_configuration.py index 47a8354ed8..de5d7d8291 100644 --- a/tests/common/test_configuration.py +++ b/tests/common/test_configuration.py @@ -1,14 +1,17 @@ import pytest from os import environ -from typing import Any, Dict, List, Mapping, MutableMapping, NewType, Optional, Tuple, Type +import datetime # noqa: I251 +from typing import Any, Dict, List, Mapping, MutableMapping, NewType, Optional, Sequence, Tuple, Type -from dlt.common.typing import TSecretValue +from dlt.common import pendulum, Decimal, Wei +from dlt.common.configuration.exceptions import ConfigFieldMissingTypeHintException, ConfigFieldTypeHintNotSupported, LookupTrace +from dlt.common.typing import StrAny, TSecretValue, extract_inner_type from dlt.common.configuration import configspec, ConfigEntryMissingException, ConfigFileNotFoundException, ConfigEnvValueCannotBeCoercedException, 
resolve -from dlt.common.configuration.specs import RunConfiguration, BaseConfiguration +from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration from dlt.common.configuration.providers import environ as environ_provider from dlt.common.utils import custom_environ -from tests.utils import preserve_environ +from tests.utils import preserve_environ, add_config_dict_to_env # used to test version __version__ = "1.0.5" @@ -22,28 +25,34 @@ 'a': 1, "b": "2" }, - 'tuple_val': (1, 2, '7'), - 'set_val': {1, 2, 3}, 'bytes_val': b'Hello World!', 'float_val': 1.18927, + "tuple_val": (1, 2, {1: "complicated dicts allowed in literal eval"}), 'any_val': "function() {}", 'none_val': "none", 'COMPLEX_VAL': { - "_": (1440, ["*"], []), - "change-email": (560, ["*"], []) - } + "_": [1440, ["*"], []], + "change-email": [560, ["*"], []] + }, + "date_val": pendulum.now(), + "dec_val": Decimal("22.38"), + "sequence_val": ["A", "B", "KAPPA"], + "gen_list_val": ["C", "Z", "N"], + "mapping_val": {"FL": 1, "FR": {"1": 2}}, + "mutable_mapping_val": {"str": "str"} } INVALID_COERCIONS = { # 'STR_VAL': 'test string', # string always OK 'int_val': "a12345", - 'bool_val': "Yes", # bool overridden by string - that is the most common problem - 'list_val': {1, "2", 3.0}, + 'bool_val': "not_bool", # bool overridden by string - that is the most common problem + 'list_val': {2: 1, "2": 3.0}, 'dict_val': "{'a': 1, 'b', '2'}", - 'tuple_val': [1, 2, '7'], - 'set_val': [1, 2, 3], 'bytes_val': 'Hello World!', - 'float_val': "invalid" + 'float_val': "invalid", + "tuple_val": "{1:2}", + "date_val": "01 May 2022", + "dec_val": True } EXCEPTED_COERCIONS = { @@ -94,13 +103,19 @@ class CoercionTestConfiguration(RunConfiguration): bool_val: bool = None list_val: list = None # type: ignore dict_val: dict = None # type: ignore - tuple_val: tuple = None # type: ignore bytes_val: bytes = None - set_val: set = None # type: ignore float_val: float = None + tuple_val: Tuple[int, int, StrAny] = None any_val: Any = None none_val: str = None COMPLEX_VAL: Dict[str, Tuple[int, List[str], List[str]]] = None + date_val: datetime.datetime = None + dec_val: Decimal = None + sequence_val: Sequence[str] = None + gen_list_val: List[str] = None + mapping_val: StrAny = None + mutable_mapping_val: MutableMapping[str, str] = None + @configspec @@ -176,6 +191,11 @@ class EmbeddedConfiguration(BaseConfiguration): namespaced: NamespacedConfiguration +@configspec +class EmbeddedOptionalConfiguration(BaseConfiguration): + instrumented: Optional[InstrumentedConfiguration] + + LongInteger = NewType("LongInteger", int) FirstOrderStr = NewType("FirstOrderStr", str) SecondOrderStr = NewType("SecondOrderStr", FirstOrderStr) @@ -187,7 +207,17 @@ def environment() -> Any: return environ -def test_initial_config_value() -> None: +def test_initial_config_state() -> None: + assert BaseConfiguration.__is_resolved__ is False + assert BaseConfiguration.__namespace__ is None + C = BaseConfiguration() + assert C.__is_resolved__ is False + assert C.is_resolved() is False + # base configuration has no resolvable fields so is never partial + assert C.is_partial() is False + + +def test_set_initial_config_value(environment: Any) -> None: # set from init method C = resolve.make_configuration(InstrumentedConfiguration(head="h", tube=["a", "b"], heels="he")) assert C.to_native_representation() == "h>a>b>he" @@ -201,7 +231,7 @@ def test_initial_config_value() -> None: assert C.to_native_representation() == "h>tu>be>xhe" -def test_check_integrity() -> None: +def 
test_check_integrity(environment: Any) -> None: with pytest.raises(RuntimeError): # head over hells resolve.make_configuration(InstrumentedConfiguration(), initial_value="he>a>b>h") @@ -223,16 +253,16 @@ def test_embedded_config(environment: Any) -> None: # resolve partial, partial is passed to embedded C = resolve.make_configuration(EmbeddedConfiguration(), accept_partial=True) - assert C.__is_partial__ - assert C.namespaced.__is_partial__ - assert C.instrumented.__is_partial__ + assert not C.__is_resolved__ + assert not C.namespaced.__is_resolved__ + assert not C.instrumented.__is_resolved__ # some are partial, some are not with custom_environ({"DLT_TEST__PASSWORD": "passwd"}): C = resolve.make_configuration(EmbeddedConfiguration(), accept_partial=True) - assert C.__is_partial__ - assert not C.namespaced.__is_partial__ - assert C.instrumented.__is_partial__ + assert not C.__is_resolved__ + assert C.namespaced.__is_resolved__ + assert not C.instrumented.__is_resolved__ # single integrity error fails all the embeds with custom_environ({"INSTRUMENTED": "he>tu>u>be>h"}): @@ -249,7 +279,13 @@ def test_provider_values_over_initial(environment: Any) -> None: with custom_environ({"INSTRUMENTED": "h>tu>u>be>he"}): C = resolve.make_configuration(EmbeddedConfiguration(), initial_value={"instrumented": "h>tu>be>xhe"}, accept_partial=True) assert C.instrumented.to_native_representation() == "h>tu>u>be>he" - assert not C.instrumented.__is_partial__ + # parent configuration is not resolved + assert not C.is_resolved() + assert C.is_partial() + # but embedded is + assert C.instrumented.__is_resolved__ + assert C.instrumented.is_resolved() + assert not C.instrumented.is_partial() def test_run_configuration_gen_name(environment: Any) -> None: @@ -291,7 +327,7 @@ def test_configuration_is_mutable_mapping(environment: Any) -> None: assert C[key] == expected_dict[key] # version is present as attr but not present in dict assert hasattr(C, "_version") - assert hasattr(C, "__is_partial__") + assert hasattr(C, "__is_resolved__") assert hasattr(C, "__namespace__") with pytest.raises(KeyError): @@ -318,19 +354,19 @@ def test_configuration_is_mutable_mapping(environment: Any) -> None: C["_version"] = "1.1.1" -def test_fields_with_no_default_to_null() -> None: +def test_fields_with_no_default_to_null(environment: Any) -> None: # fields with no default are promoted to class attrs with none assert FieldWithNoDefaultConfiguration.no_default is None assert FieldWithNoDefaultConfiguration().no_default is None -def test_init_method_gen() -> None: +def test_init_method_gen(environment: Any) -> None: C = FieldWithNoDefaultConfiguration(no_default="no_default", sentry_dsn="SENTRY") assert C.no_default == "no_default" assert C.sentry_dsn == "SENTRY" -def test_multi_derivation_defaults() -> None: +def test_multi_derivation_defaults(environment: Any) -> None: @configspec class MultiConfiguration(MockProdConfiguration, ConfigurationWithOptionalTypes, NamespacedConfiguration): @@ -346,104 +382,165 @@ class MultiConfiguration(MockProdConfiguration, ConfigurationWithOptionalTypes, assert C.__namespace__ == "DLT_TEST" -def test_raises_on_unresolved_fields() -> None: - with pytest.raises(ConfigEntryMissingException) as config_entry_missing_exception: - C = WrongConfiguration() - keys = resolve._get_resolvable_fields(C) - resolve._is_config_bounded(C, keys) - - assert 'NONECONFIGVAR' in config_entry_missing_exception.value.missing_set - +def test_raises_on_unresolved_field(environment: Any) -> None: # via make configuration - with 
pytest.raises(ConfigEntryMissingException) as config_entry_missing_exception: + with pytest.raises(ConfigEntryMissingException) as cf_missing_exc: resolve.make_configuration(WrongConfiguration()) - assert 'NONECONFIGVAR' in config_entry_missing_exception.value.missing_set + assert cf_missing_exc.value.spec_name == "WrongConfiguration" + assert "NoneConfigVar" in cf_missing_exc.value.traces + # has only one trace + trace = cf_missing_exc.value.traces["NoneConfigVar"] + assert len(trace) == 1 + assert trace[0] == LookupTrace("Environment Variables", [], "NONECONFIGVAR", None) -def test_optional_types_are_not_required() -> None: - # this should not raise an exception - keys = resolve._get_resolvable_fields(ConfigurationWithOptionalTypes()) - resolve._is_config_bounded(ConfigurationWithOptionalTypes(), keys) +def test_raises_on_many_unresolved_fields(environment: Any) -> None: + # via make configuration + with pytest.raises(ConfigEntryMissingException) as cf_missing_exc: + resolve.make_configuration(CoercionTestConfiguration()) + assert cf_missing_exc.value.spec_name == "CoercionTestConfiguration" + # get all fields that must be set + val_fields = [f for f in CoercionTestConfiguration().get_resolvable_fields() if f.lower().endswith("_val")] + traces = cf_missing_exc.value.traces + assert len(traces) == len(val_fields) + for tr_field, exp_field in zip(traces, val_fields): + assert len(traces[tr_field]) == 1 + assert traces[tr_field][0] == LookupTrace("Environment Variables", [], environ_provider.EnvironProvider.get_key_name(exp_field), None) + + +def test_accepts_optional_missing_fields(environment: Any) -> None: + # ConfigurationWithOptionalTypes has values for all non optional fields present + C = ConfigurationWithOptionalTypes() + assert not C.is_partial() # make optional config resolve.make_configuration(ConfigurationWithOptionalTypes()) # make config with optional values - resolve.make_configuration(ProdConfigurationWithOptionalTypes(), initial_value={"INT_VAL": None}) + resolve.make_configuration(ProdConfigurationWithOptionalTypes(), initial_value={"int_val": None}) + # make config with optional embedded config + C = resolve.make_configuration(EmbeddedOptionalConfiguration()) + # embedded config was not fully resolved + assert not C.instrumented.__is_resolved__ + assert not C.instrumented.is_resolved() + assert C.instrumented.is_partial() -def test_configuration_apply_adds_environment_variable_to_config(environment: Any) -> None: +def test_resolves_from_environ(environment: Any) -> None: environment["NONECONFIGVAR"] = "Some" C = WrongConfiguration() - keys = resolve._get_resolvable_fields(C) - resolve._resolve_config_fields(C, keys, accept_partial=False) - resolve._is_config_bounded(C, keys) + resolve._resolve_config_fields(C, accept_partial=False) + assert not C.is_partial() assert C.NoneConfigVar == environment["NONECONFIGVAR"] -def test_configuration_resolve_env_var(environment: Any) -> None: - environment["TEST_BOOL"] = 'True' +def test_resolves_from_environ_with_coercion(environment: Any) -> None: + environment["TEST_BOOL"] = 'yes' C = SimpleConfiguration() - keys = resolve._get_resolvable_fields(C) - resolve._resolve_config_fields(C, keys, accept_partial=False) - resolve._is_config_bounded(C, keys) + resolve._resolve_config_fields(C, accept_partial=False) + assert not C.is_partial() # value will be coerced to bool assert C.test_bool is True def test_find_all_keys() -> None: - keys = resolve._get_resolvable_fields(VeryWrongConfiguration()) + keys = 
VeryWrongConfiguration().get_resolvable_fields() # assert hints and types: LOG_COLOR had it hint overwritten in derived class assert set({'str_val': str, 'int_val': int, 'NoneConfigVar': str, 'log_color': str}.items()).issubset(keys.items()) -def test_coercions(environment: Any) -> None: - for key, value in COERCIONS.items(): - environment[key.upper()] = str(value) +def test_coercion_to_hint_types(environment: Any) -> None: + add_config_dict_to_env(COERCIONS) C = CoercionTestConfiguration() - keys = resolve._get_resolvable_fields(C) - resolve._resolve_config_fields(C, keys, accept_partial=False) - resolve._is_config_bounded(C, keys) + resolve._resolve_config_fields(C, accept_partial=False) for key in COERCIONS: assert getattr(C, key) == COERCIONS[key] +def test_values_serialization() -> None: + # test tuple + t_tuple = (1, 2, 3, "A") + v = resolve.serialize_value(t_tuple) + assert v == "(1, 2, 3, 'A')" # literal serialization + assert resolve.deserialize_value("K", v, tuple) == t_tuple + + # test list + t_list = ["a", 3, True] + v = resolve.serialize_value(t_list) + assert v == '["a", 3, true]' # json serialization + assert resolve.deserialize_value("K", v, list) == t_list + + # test datetime + t_date = pendulum.now() + v = resolve.serialize_value(t_date) + assert resolve.deserialize_value("K", v, datetime.datetime) == t_date + + # test wei + t_wei = Wei.from_int256(10**16, decimals=18) + v = resolve.serialize_value(t_wei) + assert v == "0.01" + # can be deserialized into + assert resolve.deserialize_value("K", v, float) == 0.01 + assert resolve.deserialize_value("K", v, Decimal) == Decimal("0.01") + assert resolve.deserialize_value("K", v, Wei) == Wei("0.01") + + def test_invalid_coercions(environment: Any) -> None: C = CoercionTestConfiguration() - config_keys = resolve._get_resolvable_fields(C) + add_config_dict_to_env(INVALID_COERCIONS) for key, value in INVALID_COERCIONS.items(): try: - environment[key.upper()] = str(value) - resolve._resolve_config_fields(C, config_keys, accept_partial=False) + resolve._resolve_config_fields(C, accept_partial=False) except ConfigEnvValueCannotBeCoercedException as coerc_exc: # must fail exactly on expected value - if coerc_exc.attr_name != key: + if coerc_exc.field_name != key: raise # overwrite with valid value and go to next env - environment[key.upper()] = str(COERCIONS[key]) + environment[key.upper()] = resolve.serialize_value(COERCIONS[key]) continue raise AssertionError("%s was coerced with %s which is invalid type" % (key, value)) def test_excepted_coercions(environment: Any) -> None: C = CoercionTestConfiguration() - config_keys = resolve._get_resolvable_fields(C) - for k, v in EXCEPTED_COERCIONS.items(): - environment[k.upper()] = str(v) - resolve._resolve_config_fields(C, config_keys, accept_partial=False) + add_config_dict_to_env(COERCIONS) + add_config_dict_to_env(EXCEPTED_COERCIONS, overwrite_keys=True) + resolve._resolve_config_fields(C, accept_partial=False) for key in EXCEPTED_COERCIONS: assert getattr(C, key) == COERCED_EXCEPTIONS[key] +def test_config_with_unsupported_types_in_hints(environment: Any) -> None: + with pytest.raises(ConfigFieldTypeHintNotSupported): + + @configspec + class InvalidHintConfiguration(BaseConfiguration): + tuple_val: tuple = None # type: ignore + set_val: set = None # type: ignore + InvalidHintConfiguration() + + +def test_config_with_no_hints(environment: Any) -> None: + with pytest.raises(ConfigFieldMissingTypeHintException): + + @configspec + class NoHintConfiguration(BaseConfiguration): + tuple_val = 
None + NoHintConfiguration() + + + + + def test_make_configuration(environment: Any) -> None: # fill up configuration environment["NONECONFIGVAR"] = "1" C = resolve.make_configuration(WrongConfiguration()) - assert not C.__is_partial__ + assert C.__is_resolved__ assert C.NoneConfigVar == "1" @@ -482,7 +579,8 @@ def test_accept_partial(environment: Any) -> None: C = resolve.make_configuration(WrongConfiguration(), accept_partial=True) assert C.NoneConfigVar is None # partial resolution - assert C.__is_partial__ + assert not C.__is_resolved__ + assert C.is_partial() def test_finds_version(environment: Any) -> None: @@ -583,8 +681,13 @@ def test_configuration_files(environment: Any) -> None: def test_namespaced_configuration(environment: Any) -> None: with pytest.raises(ConfigEntryMissingException) as exc_val: resolve.make_configuration(NamespacedConfiguration()) - assert exc_val.value.missing_set == ["DLT_TEST__PASSWORD"] - assert exc_val.value.namespace == "DLT_TEST" + assert list(exc_val.value.traces.keys()) == ["password"] + assert exc_val.value.spec_name == "NamespacedConfiguration" + # check trace + traces = exc_val.value.traces["password"] + # only one provider and namespace was tried + assert len(traces) == 1 + assert traces[0] == LookupTrace("Environment Variables", ("DLT_TEST",), "DLT_TEST__PASSWORD", None) # init vars work without namespace C = resolve.make_configuration(NamespacedConfiguration(), initial_value={"password": "PASS"}) @@ -598,6 +701,7 @@ def test_namespaced_configuration(environment: Any) -> None: C = resolve.make_configuration(NamespacedConfiguration()) assert C.password == "PASS" + def coerce_single_value(key: str, value: str, hint: Type[Any]) -> Any: - hint = resolve._extract_simple_type(hint) - return resolve._coerce_single_value(key, value, hint) + hint = extract_inner_type(hint) + return resolve.deserialize_value(key, value, hint) diff --git a/tests/common/test_typing.py b/tests/common/test_typing.py index c80461dd59..8fdb5afbb0 100644 --- a/tests/common/test_typing.py +++ b/tests/common/test_typing.py @@ -1,7 +1,7 @@ -from typing import List, Literal, Mapping, MutableMapping, MutableSequence, Sequence, TypedDict, Optional +from typing import List, Literal, Mapping, MutableMapping, MutableSequence, NewType, Sequence, TypeVar, TypedDict, Optional -from dlt.common.typing import extract_optional_type, is_dict_generic_type, is_list_generic_type, is_literal_type, is_optional_type, is_typeddict +from dlt.common.typing import extract_inner_type, extract_optional_type, is_dict_generic_type, is_list_generic_type, is_literal_type, is_newtype_type, is_optional_type, is_typeddict @@ -20,7 +20,7 @@ def test_is_typeddict() -> None: assert is_typeddict(Sequence[str]) is False -def test_is_list_type() -> None: +def test_is_list_generic_type() -> None: # yes - we need a generic type assert is_list_generic_type(list) is False assert is_list_generic_type(List[str]) is True @@ -28,13 +28,13 @@ def test_is_list_type() -> None: assert is_list_generic_type(MutableSequence[str]) is True -def test_is_dict_type() -> None: +def test_is_dict_generic_type() -> None: assert is_dict_generic_type(dict) is False assert is_dict_generic_type(Mapping[str, str]) is True assert is_dict_generic_type(MutableMapping[str, str]) is True -def test_literal() -> None: +def test_is_literal() -> None: assert is_literal_type(TTestLi) is True assert is_literal_type("a") is False assert is_literal_type(List[str]) is False @@ -46,3 +46,22 @@ def test_optional() -> None: assert is_optional_type(TTestTyDi) is 
False assert extract_optional_type(TOptionalLi) is TTestLi assert extract_optional_type(TOptionalTyDi) is TTestTyDi + + +def test_is_newtype() -> None: + assert is_newtype_type(NewType("NT1", str)) is True + assert is_newtype_type(TypeVar("TV1", bound=str)) is False + assert is_newtype_type(1) is False + + +def test_extract_inner_type() -> None: + assert extract_inner_type(1) == 1 + assert extract_inner_type(str) is str + assert extract_inner_type(NewType("NT1", str)) is str + assert extract_inner_type(NewType("NT2", NewType("NT3", int))) is int + assert extract_inner_type(Optional[NewType("NT3", bool)]) is bool + l_1 = Literal[1, 2, 3] + assert extract_inner_type(l_1) is int + nt_l_2 = NewType("NTL2", float) + l_2 = Literal[nt_l_2(1.238), nt_l_2(2.343)] + assert extract_inner_type(l_2) is float diff --git a/tests/utils.py b/tests/utils.py index c3a769a4c7..adf954899c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,13 +1,14 @@ import multiprocessing import platform +from typing import Any, Mapping import requests -from typing import Type import pytest import logging from os import environ -from dlt.common.configuration.resolve import _get_resolvable_fields, make_configuration -from dlt.common.configuration.specs import RunConfiguration +from dlt.common.configuration.providers import EnvironProvider +from dlt.common.configuration.resolve import make_configuration, serialize_value +from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration from dlt.common.logger import init_logging_from_config from dlt.common.storages import FileStorage from dlt.common.schema import Schema @@ -77,15 +78,17 @@ def clean_test_storage(init_normalize: bool = False, init_loader: bool = False) return storage -def add_config_to_env(config: RunConfiguration) -> None: +def add_config_to_env(config: BaseConfiguration) -> None: # write back default values in configuration back into environment - possible_attrs = _get_resolvable_fields(config).keys() - for attr in possible_attrs: - env_key = attr.upper() - if env_key not in environ: - v = getattr(config, attr) + return add_config_dict_to_env(dict(config), config.__namespace__) + + +def add_config_dict_to_env(dict_: Mapping[str, Any], namespace: str = None, overwrite_keys: bool = False) -> None: + for k, v in dict_.items(): + env_key = EnvironProvider.get_key_name(k, namespace) + if env_key not in environ or overwrite_keys: if v is not None: - environ[env_key] = str(v) + environ[env_key] = serialize_value(v) def create_schema_with_name(schema_name) -> Schema: From f16285b3766020f05e10d44949b382f5e7cb6145 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 5 Oct 2022 23:12:30 +0200 Subject: [PATCH 26/66] moves config tests --- tests/common/{ => configuration}/test_configuration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename tests/common/{ => configuration}/test_configuration.py (99%) diff --git a/tests/common/test_configuration.py b/tests/common/configuration/test_configuration.py similarity index 99% rename from tests/common/test_configuration.py rename to tests/common/configuration/test_configuration.py index de5d7d8291..c19f079ea8 100644 --- a/tests/common/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -687,7 +687,7 @@ def test_namespaced_configuration(environment: Any) -> None: traces = exc_val.value.traces["password"] # only one provider and namespace was tried assert len(traces) == 1 - assert traces[0] == LookupTrace("Environment Variables", ("DLT_TEST",), "DLT_TEST__PASSWORD", None) + 
assert traces[0] == LookupTrace("Environment Variables", ["DLT_TEST"], "DLT_TEST__PASSWORD", None) # init vars work without namespace C = resolve.make_configuration(NamespacedConfiguration(), initial_value={"password": "PASS"}) From a80b2f5c6e3baab381e557e4404eadee084b0e97 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 12 Oct 2022 22:45:03 +0200 Subject: [PATCH 27/66] adds config inject decorator, basic tests, applies decorator to buffered writers and pipe --- dlt/common/configuration/__init__.py | 1 + dlt/common/configuration/inject.py | 138 ++++++++++++++++-- dlt/common/data_writers/buffered.py | 7 +- experiments/pipeline/configuration.py | 75 ---------- experiments/pipeline/pipe.py | 17 ++- tests/common/configuration/test_inject.py | 164 ++++++++++++++++++++++ 6 files changed, 308 insertions(+), 94 deletions(-) create mode 100644 tests/common/configuration/test_inject.py diff --git a/dlt/common/configuration/__init__.py b/dlt/common/configuration/__init__.py index 0459774597..212d7b0575 100644 --- a/dlt/common/configuration/__init__.py +++ b/dlt/common/configuration/__init__.py @@ -1,5 +1,6 @@ from .specs.base_configuration import configspec, is_valid_hint # noqa: F401 from .resolve import make_configuration # noqa: F401 +from .inject import with_config from .exceptions import ( # noqa: F401 ConfigEntryMissingException, ConfigEnvValueCannotBeCoercedException, ConfigIntegrityException, ConfigFileNotFoundException) diff --git a/dlt/common/configuration/inject.py b/dlt/common/configuration/inject.py index 55f0b888e6..fe48c82de6 100644 --- a/dlt/common/configuration/inject.py +++ b/dlt/common/configuration/inject.py @@ -1,25 +1,135 @@ -from typing import Dict, Type, TypeVar +import re +import inspect +from makefun import wraps +from types import ModuleType +from typing import Callable, List, Dict, Type, Any, Optional, Tuple, overload +from inspect import Signature, Parameter -from dlt.common.configuration.specs.base_configuration import BaseConfiguration +from dlt.common.typing import StrAny, TFun, AnyFun +from dlt.common.configuration.resolve import make_configuration +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, is_valid_hint, configspec -TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration) +# [^.^_]+ splits by . or _ +_SLEEPING_CAT_SPLIT = re.compile("[^.^_]+") -class Container: +@overload +def with_config(func: TFun, /, spec: Type[BaseConfiguration] = None, only_kw: bool = False, namespaces: Tuple[str, ...] = ()) -> TFun: + ... - _INSTANCE: "Container" = None - def __new__(cls: Type["Container"]) -> "Container": - if not cls._INSTANCE: - cls._INSTANCE = super().__new__(cls) - return cls._INSTANCE +@overload +def with_config(func: None = ..., /, spec: Type[BaseConfiguration] = None, only_kw: bool = False, namespaces: Tuple[str, ...] = ()) -> Callable[[TFun], TFun]: + ... - def __init__(self) -> None: - self.configurations: Dict[Type[BaseConfiguration], BaseConfiguration] = {} +def with_config(func: Optional[AnyFun] = None, /, spec: Type[BaseConfiguration] = None, only_kw: bool = False, namespaces: Tuple[str, ...] 
= ()) -> Callable[[TFun], TFun]: + namespace_f: Callable[[StrAny], str] = None + # namespace may be a function from function arguments to namespace + if callable(namespaces): + namespace_f = namespaces - def __getitem__(self, spec: Type[TConfiguration]) -> TConfiguration: - # return existing config object or create it from spec - return self.configurations.setdefault(spec, spec()) + def decorator(f: TFun) -> TFun: + SPEC: Type[BaseConfiguration] = None + sig: Signature = inspect.signature(f) + kwargs_par = next((p for p in sig.parameters.values() if p.kind == Parameter.VAR_KEYWORD), None) + if spec is None: + SPEC = _spec_from_signature(_get_spec_name_from_f(f), inspect.getmodule(f), sig, only_kw) + else: + SPEC = spec + + # for all positional parameters that do not have default value, set default + for p in sig.parameters.values(): + if hasattr(SPEC, p.name) and p.default == Parameter.empty: + p._default = None # type: ignore + + @wraps(f, new_sig=sig) + def _wrap(*args: Any, **kwargs: Any) -> Any: + # for calls providing all parameters to the func, configuration may not be resolved + # if len(args) + len(kwargs) == len(sig.parameters): + # return f(*args, **kwargs) + + # bind parameters to signature + bound_args = sig.bind_partial(*args, **kwargs) + bound_args.apply_defaults() + # if namespace derivation function was provided then call it + nonlocal namespaces + if namespace_f: + namespaces = (namespace_f(bound_args.arguments), ) + # namespaces may be a string + if isinstance(namespaces, str): + namespaces = (namespaces,) + # resolve SPEC + config = make_configuration(SPEC(), namespaces=namespaces, initial_value=bound_args.arguments) + resolved_params = dict(config) + # overwrite or add resolved params + for p in sig.parameters.values(): + if p.name in resolved_params: + bound_args.arguments[p.name] = resolved_params.pop(p.name) + # pass all other config parameters into kwargs if present + if kwargs_par is not None: + bound_args.arguments[kwargs_par.name].update(resolved_params) + # call the function with injected config + return f(*bound_args.args, **bound_args.kwargs) + + return _wrap # type: ignore + + # See if we're being called as @with_config or @with_config(). + if func is None: + # We're called with parens. + return decorator + + if not callable(func): + raise ValueError("First parameter to the with_config must be callable ie. by using it as function decorator") + + # We're called as @with_config without parens. 
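A sketch of how the decorator above can be applied when no explicit spec is given, so the SPEC is synthesized from the function signature; open_writer and the shown values are illustrative assumptions, while the NAMESPACE__FIELD environment key format follows the environ provider and the lookup traces asserted in the tests:

    import os
    from dlt.common.configuration import with_config

    @with_config(namespaces=("data_writer",))
    def open_writer(file_format: str = "jsonl", buffer_max_items: int = 5000):
        # hypothetical function: a spec with file_format and buffer_max_items is derived from this signature
        return file_format, buffer_max_items

    os.environ["DATA_WRITER__BUFFER_MAX_ITEMS"] = "100"
    open_writer()  # ("jsonl", 100) - the value is found in the namespace and coerced to int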
+ return decorator(func) + + +def _get_spec_name_from_f(f: AnyFun) -> str: + func_name = f.__qualname__.replace(".", "") # func qual name contains position in the module, separated by dots + + def _first_up(s: str) -> str: + return s[0].upper() + s[1:] + + return "".join(map(_first_up, _SLEEPING_CAT_SPLIT.findall(func_name))) + "Configuration" + + +def _spec_from_signature(name: str, module: ModuleType, sig: Signature, kw_only: bool = False) -> Type[BaseConfiguration]: + # synthesize configuration from the signature + fields: Dict[str, Any] = {} + annotations: Dict[str, Any] = {} + + for p in sig.parameters.values(): + # skip *args and **kwargs, skip typical method params and if kw_only flag is set: accept KEYWORD ONLY args + if p.kind not in (Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL) and p.name not in ["self", "cls"] and \ + (kw_only and p.kind == Parameter.KEYWORD_ONLY or not kw_only): + field_type = Any if p.annotation == Parameter.empty else p.annotation + if is_valid_hint(field_type): + field_default = None if p.default == Parameter.empty else p.default + # try to get type from default + if field_type is Any and field_default: + field_type = type(field_default) + # make type optional if explicit None is provided as default + if p.default is None: + field_type = Optional[field_type] + # set annotations + annotations[p.name] = field_type + # set field with default value + + fields[p.name] = field_default + # new type goes to the module where sig was declared + fields["__module__"] = module.__name__ + # set annotations so they are present in __dict__ + fields["__annotations__"] = annotations + # synthesize type + T: Type[BaseConfiguration] = type(name, (BaseConfiguration,), fields) + # add to the module + setattr(module, name, T) + SPEC = configspec(init=False)(T) + # print(f"SYNTHESIZED {SPEC} in {inspect.getmodule(SPEC)} for sig {sig}") + # import dataclasses + # print("\n".join(map(str, dataclasses.fields(SPEC)))) + return SPEC diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py index f27db9a65a..7b5e3f1054 100644 --- a/dlt/common/data_writers/buffered.py +++ b/dlt/common/data_writers/buffered.py @@ -1,4 +1,4 @@ -from typing import List, IO, Any +from typing import List, IO, Any, Optional from dlt.common.utils import uniq_id from dlt.common.typing import TDataItem @@ -7,10 +7,13 @@ from dlt.common.data_writers.exceptions import BufferedDataWriterClosed, InvalidFileNameTemplateException from dlt.common.data_writers.writers import DataWriter from dlt.common.schema.typing import TTableSchemaColumns +from dlt.common.configuration import with_config class BufferedDataWriter: - def __init__(self, file_format: TLoaderFileFormat, file_name_template: str, buffer_max_items: int = 5000, file_max_items: int = None, file_max_bytes: int = None): + + @with_config(only_kw=True, namespaces=("data_writer",)) + def __init__(self, file_format: TLoaderFileFormat, file_name_template: str, *, buffer_max_items: int = 5000, file_max_items: int = None, file_max_bytes: int = None): self.file_format = file_format self._file_format_spec = DataWriter.data_format_from_file_format(self.file_format) # validate if template has correct placeholders diff --git a/experiments/pipeline/configuration.py b/experiments/pipeline/configuration.py index 56697f4538..159f13c4c4 100644 --- a/experiments/pipeline/configuration.py +++ b/experiments/pipeline/configuration.py @@ -36,78 +36,3 @@ def get_config(SPEC: Type[TAny], key: str = None, namespace: str = None, initial def 
spec_from_dict(): pass - -def spec_from_signature(name: str, sig: Signature) -> Type[BaseConfiguration]: - # synthesize configuration from the signature - fields: List[dataclasses.Field] = [] - for p in sig.parameters.values(): - # skip *args and **kwargs - if p.kind not in (Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL) and p.name not in ["self", "cls"]: - field_type = Any if p.annotation == Parameter.empty else p.annotation - if is_valid_hint(field_type): - field_default = None if p.default == Parameter.empty else dataclasses.field(default=p.default) - if field_default: - # correct the type if Any - if field_type is Any: - field_type = type(p.default) - fields.append((p.name, field_type, field_default)) - else: - fields.append((p.name, field_type)) - print(fields) - SPEC = dataclasses.make_dataclass(name + "_CONFIG", fields, bases=(BaseConfiguration,), init=False) - print("synthesized") - print(SPEC) - # print(SPEC()) - return SPEC - - -def with_config(func = None, /, spec: Type[BaseConfiguration] = None) -> TFun: - - def decorator(f: TFun) -> TFun: - SPEC: Type[BaseConfiguration] = None - sig: Signature = inspect.signature(f) - kwargs_par = next((p for p in sig.parameters.values() if p.kind == Parameter.VAR_KEYWORD), None) - # pos_params = [p.name for p in sig.parameters.values() if p.kind in _POS_PARAMETER_KINDS] - # kw_params = [p.name for p in sig.parameters.values() if p.kind not in _POS_PARAMETER_KINDS] - - if spec is None: - SPEC = spec_from_signature(f.__name__, sig) - else: - SPEC = spec - - @wraps(f) - def _wrap(*args: Any, **kwargs: Any) -> Any: - # for calls providing all parameters to the func, configuration may not be resolved - # if len(args) + len(kwargs) == len(sig.parameters): - # return f(*args, **kwargs) - - # bind parameters to signature - bound_args = sig.bind_partial(*args, **kwargs) - bound_args.apply_defaults() - # resolve SPEC - config = get_config(SPEC, SPEC, initial_value=bound_args.arguments) - resolved_params = dict(config) - print("RESOLVED") - print(resolved_params) - # overwrite or add resolved params - for p in sig.parameters.values(): - if p.name in resolved_params: - bound_args.arguments[p.name] = resolved_params.pop(p.name) - # pass all other config parameters into kwargs if present - if kwargs_par is not None: - bound_args.arguments[kwargs_par.name].update(resolved_params) - # call the function with injected config - return f(*bound_args.args, **bound_args.kwargs) - - return _wrap - - # See if we're being called as @with_config or @with_config(). - if func is None: - # We're called with parens. - return decorator - - if not callable(func): - raise ValueError("First parameter to the with_config must be callable ie. by using it as function decorator") - - # We're called as @with_config without parens. 
- return decorator(func) diff --git a/experiments/pipeline/pipe.py b/experiments/pipeline/pipe.py index 9b0a43dac0..da32a6c12f 100644 --- a/experiments/pipeline/pipe.py +++ b/experiments/pipeline/pipe.py @@ -5,6 +5,8 @@ from copy import deepcopy from threading import Thread from typing import Optional, Sequence, Union, Callable, Iterable, Iterator, List, NamedTuple, Awaitable, Tuple, Type, TYPE_CHECKING +from dlt.common.configuration.inject import with_config +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec from dlt.common.typing import TDataItem from dlt.common.sources import TDirectDataItem, TResolvableDataItem @@ -190,7 +192,14 @@ def __repr__(self) -> str: class PipeIterator(Iterator[PipeItem]): - def __init__(self, max_parallel_items: int = 100, worker_threads: int = 5, futures_poll_interval: float = 0.01) -> None: + @configspec + class PipeIteratorConfiguration: + max_parallel_items: int = 100 + worker_threads: int = 5 + futures_poll_interval: float = 0.01 + + + def __init__(self, max_parallel_items: int, worker_threads, futures_poll_interval: float) -> None: self.max_parallel_items = max_parallel_items self.worker_threads = worker_threads self.futures_poll_interval = futures_poll_interval @@ -202,7 +211,8 @@ def __init__(self, max_parallel_items: int = 100, worker_threads: int = 5, futur self._futures: List[FuturePipeItem] = [] @classmethod - def from_pipe(cls, pipe: Pipe, max_parallelism: int = 100, worker_threads: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": + @with_config(spec=PipeIteratorConfiguration) + def from_pipe(cls, pipe: Pipe, *, max_parallelism: int = 100, worker_threads: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": if pipe.parent: pipe = pipe.full_pipe() # head must be iterator @@ -214,7 +224,8 @@ def from_pipe(cls, pipe: Pipe, max_parallelism: int = 100, worker_threads: int = return extract @classmethod - def from_pipes(cls, pipes: Sequence[Pipe], yield_parents: bool = True, max_parallelism: int = 100, worker_threads: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": + @with_config(spec=PipeIteratorConfiguration) + def from_pipes(cls, pipes: Sequence[Pipe], yield_parents: bool = True, *, max_parallelism: int = 100, worker_threads: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": extract = cls(max_parallelism, worker_threads, futures_poll_interval) # clone all pipes before iterating (recursively) as we will fork them and this add steps pipes = PipeIterator.clone_pipes(pipes) diff --git a/tests/common/configuration/test_inject.py b/tests/common/configuration/test_inject.py new file mode 100644 index 0000000000..7e9647c202 --- /dev/null +++ b/tests/common/configuration/test_inject.py @@ -0,0 +1,164 @@ +import inspect +from typing import Any, Optional + +from dlt.common import Decimal +from dlt.common.typing import TSecretValue +from dlt.common.configuration.inject import _spec_from_signature, _get_spec_name_from_f, with_config +from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration + + +_DECIMAL_DEFAULT = Decimal("0.01") +_SECRET_DEFAULT = TSecretValue("PASS") +_CONFIG_DEFAULT = RunConfiguration() + + +def test_synthesize_spec_from_sig() -> None: + + # spec from typed signature without defaults + + def f_typed(p1: str, p2: Decimal, p3: Any, p4: Optional[RunConfiguration], p5: TSecretValue) -> None: + pass + + SPEC = _spec_from_signature(f_typed.__name__, inspect.getmodule(f_typed), inspect.signature(f_typed)) + assert SPEC.p1 is None 
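The pipe changes above show the second mode of the decorator, where a hand-written spec is passed explicitly instead of being synthesized. A minimal sketch of that pattern with hypothetical names, assuming the same resolution rules as for derived specs:

    from dlt.common.configuration import configspec, with_config
    from dlt.common.configuration.specs import BaseConfiguration

    @configspec
    class ExtractorConfiguration(BaseConfiguration):  # hypothetical spec, mirrors PipeIteratorConfiguration
        workers: int = 5
        poll_interval: float = 0.01

    @with_config(spec=ExtractorConfiguration)
    def run_extract(source_name: str, *, workers: int = 5, poll_interval: float = 0.01):
        # only arguments that exist on the spec are injected; source_name is passed through untouched
        return source_name, workers, poll_interval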
+ assert SPEC.p2 is None + assert SPEC.p3 is None + assert SPEC.p4 is None + assert SPEC.p5 is None + fields = SPEC().get_resolvable_fields() + assert fields == {"p1": str, "p2": Decimal, "p3": Any, "p4": Optional[RunConfiguration], "p5": TSecretValue} + + # spec from typed signatures with defaults + + def f_typed_default(t_p1: str = "str", t_p2: Decimal = _DECIMAL_DEFAULT, t_p3: Any = _SECRET_DEFAULT, t_p4: RunConfiguration = _CONFIG_DEFAULT, t_p5: str = None) -> None: + pass + + SPEC = _spec_from_signature(f_typed_default.__name__, inspect.getmodule(f_typed_default), inspect.signature(f_typed_default)) + assert SPEC.t_p1 == "str" + assert SPEC.t_p2 == _DECIMAL_DEFAULT + assert SPEC.t_p3 == _SECRET_DEFAULT + assert isinstance(SPEC.t_p4, RunConfiguration) + assert SPEC.t_p5 is None + fields = SPEC().get_resolvable_fields() + # Any will not assume TSecretValue type because at runtime it's a str + # setting default as None will convert type into optional (t_p5) + assert fields == {"t_p1": str, "t_p2": Decimal, "t_p3": str, "t_p4": RunConfiguration, "t_p5": Optional[str]} + + # spec from untyped signature + + def f_untyped(untyped_p1, untyped_p2) -> None: + pass + + SPEC = _spec_from_signature(f_untyped.__name__, inspect.getmodule(f_untyped), inspect.signature(f_untyped)) + assert SPEC.untyped_p1 is None + assert SPEC.untyped_p2 is None + fields = SPEC().get_resolvable_fields() + assert fields == {"untyped_p1": Any, "untyped_p2": Any,} + + # spec types derived from defaults + + + def f_untyped_default(untyped_p1 = "str", untyped_p2 = _DECIMAL_DEFAULT, untyped_p3 = _CONFIG_DEFAULT, untyped_p4 = None) -> None: + pass + + + SPEC = _spec_from_signature(f_untyped_default.__name__, inspect.getmodule(f_untyped_default), inspect.signature(f_untyped_default)) + assert SPEC.untyped_p1 == "str" + assert SPEC.untyped_p2 == _DECIMAL_DEFAULT + assert isinstance(SPEC.untyped_p3, RunConfiguration) + assert SPEC.untyped_p4 is None + fields = SPEC().get_resolvable_fields() + # untyped_p4 converted to Optional[Any] + assert fields == {"untyped_p1": str, "untyped_p2": Decimal, "untyped_p3": RunConfiguration, "untyped_p4": Optional[Any]} + + # spec from signatures containing positional only and keywords only args + + def f_pos_kw_only(pos_only_1, pos_only_2: str = "default", /, *, kw_only_1, kw_only_2: int = 2) -> None: + pass + + SPEC = _spec_from_signature(f_pos_kw_only.__name__, inspect.getmodule(f_pos_kw_only), inspect.signature(f_pos_kw_only)) + assert SPEC.pos_only_1 is None + assert SPEC.pos_only_2 == "default" + assert SPEC.kw_only_1 is None + assert SPEC.kw_only_2 == 2 + fields = SPEC().get_resolvable_fields() + assert fields == {"pos_only_1": Any, "pos_only_2": str, "kw_only_1": Any, "kw_only_2": int} + + # kw_only = True will filter in keywords only parameters + SPEC = _spec_from_signature(f_pos_kw_only.__name__, inspect.getmodule(f_pos_kw_only), inspect.signature(f_pos_kw_only), kw_only=True) + assert SPEC.kw_only_1 is None + assert SPEC.kw_only_2 == 2 + assert not hasattr(SPEC, "pos_only_1") + fields = SPEC().get_resolvable_fields() + assert fields == {"kw_only_1": Any, "kw_only_2": int} + + def f_variadic(var_1: str, *args, kw_var_1: str, **kwargs) -> None: + pass + + SPEC = _spec_from_signature(f_variadic.__name__, inspect.getmodule(f_variadic), inspect.signature(f_variadic)) + assert SPEC.var_1 is None + assert SPEC.kw_var_1 is None + assert not hasattr(SPEC, "args") + fields = SPEC().get_resolvable_fields() + assert fields == {"var_1": str, "kw_var_1": str} + + +def 
test_inject_with_non_injectable_param() -> None: + # one of parameters in signature has not valid hint and is skipped (ie. from_pipe) + pass + + +def test_inject_with_spec() -> None: + pass + + +def test_inject_with_str_namespaces() -> None: + # namespaces param is str not tuple + pass + + +def test_inject_with_func_namespace() -> None: + # function to get namespaces from the arguments is provided + pass + + +def test_inject_on_class_and_methods() -> None: + pass + + +def test_set_defaults_for_positional_args() -> None: + # set defaults for positional args that are part of derived SPEC + # set defaults for positional args that are part of provided SPEC + pass + + + +def test_auto_derived_spec_type_name() -> None: + + + class AutoNameTest: + @with_config + def __init__(self, pos_par, /, kw_par) -> None: + pass + + @classmethod + @with_config + def make_class(cls, pos_par, /, kw_par) -> None: + pass + + @staticmethod + @with_config + def make_stuff(pos_par, /, kw_par) -> None: + pass + + @with_config + def stuff_test(pos_par, /, kw_par) -> None: + pass + + # name is composed via __qualname__ of func + assert _get_spec_name_from_f(AutoNameTest.__init__) == "TestAutoDerivedSpecTypeNameAutoNameTestInitConfiguration" + # synthesized spec present in current module + assert "TestAutoDerivedSpecTypeNameAutoNameTestInitConfiguration" in globals() + # instantiate + C: BaseConfiguration = globals()["TestAutoDerivedSpecTypeNameAutoNameTestInitConfiguration"]() + assert C.get_resolvable_fields() == {"pos_par": Any, "kw_par": Any} \ No newline at end of file From 1e0695f5f46b92efba947e2bd1a8070324d818a9 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 12 Oct 2022 22:45:54 +0200 Subject: [PATCH 28/66] adds injection container with tests --- dlt/common/configuration/container.py | 62 +++++++++ .../configuration/providers/container.py | 31 +++++ tests/common/configuration/test_container.py | 118 ++++++++++++++++++ 3 files changed, 211 insertions(+) create mode 100644 dlt/common/configuration/container.py create mode 100644 dlt/common/configuration/providers/container.py create mode 100644 tests/common/configuration/test_container.py diff --git a/dlt/common/configuration/container.py b/dlt/common/configuration/container.py new file mode 100644 index 0000000000..f28fa6e9fa --- /dev/null +++ b/dlt/common/configuration/container.py @@ -0,0 +1,62 @@ +from contextlib import contextmanager +from typing import Dict, Iterator, Type, TypeVar + +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec +from dlt.common.configuration.exceptions import ContainerInjectableConfigurationMangled + + +@configspec +class ContainerInjectableConfiguration(BaseConfiguration): + """Base class for all configurations that may be injected from Container.""" + pass + + +TConfiguration = TypeVar("TConfiguration", bound=ContainerInjectableConfiguration) + + +class Container: + + _INSTANCE: "Container" = None + + configurations: Dict[Type[ContainerInjectableConfiguration], ContainerInjectableConfiguration] + + def __new__(cls: Type["Container"]) -> "Container": + if not cls._INSTANCE: + cls._INSTANCE = super().__new__(cls) + cls._INSTANCE.configurations = {} + return cls._INSTANCE + + def __init__(self) -> None: + pass + + def __getitem__(self, spec: Type[TConfiguration]) -> TConfiguration: + # return existing config object or create it from spec + if not issubclass(spec, ContainerInjectableConfiguration): + raise KeyError(f"{spec.__name__} is not injectable") + + return 
self.configurations.setdefault(spec, spec()) # type: ignore + + def __contains__(self, spec: Type[TConfiguration]) -> bool: + return spec in self.configurations + + @contextmanager + def injectable_configuration(self, config: TConfiguration) -> Iterator[TConfiguration]: + spec = type(config) + previous_config: ContainerInjectableConfiguration = None + if spec in self.configurations: + previous_config = self.configurations[spec] + # set new config and yield context + try: + self.configurations[spec] = config + yield config + finally: + # before setting the previous config for given spec, check if there was no overlapping modification + if self.configurations[spec] is config: + # config is injected for spec so restore previous + if previous_config is None: + del self.configurations[spec] + else: + self.configurations[spec] = previous_config + else: + # value was modified in the meantime and not restored + raise ContainerInjectableConfigurationMangled(spec, self.configurations[spec], config) diff --git a/dlt/common/configuration/providers/container.py b/dlt/common/configuration/providers/container.py new file mode 100644 index 0000000000..1fc40559b1 --- /dev/null +++ b/dlt/common/configuration/providers/container.py @@ -0,0 +1,31 @@ +from typing import Any, Optional, Type, Tuple + +from dlt.common.configuration.container import Container + +from .provider import Provider + + +class ContainerProvider(Provider): + + NAME = "Injectable Configuration" + + @property + def name(self) -> str: + return ContainerProvider.NAME + + def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Optional[Any], str]: + assert namespaces == () + # get container singleton + container = Container() + if hint in container: + return Container()[hint], hint.__name__ + else: + return None, str(hint) + + @property + def supports_secrets(self) -> bool: + return True + + @property + def supports_namespaces(self) -> bool: + return False diff --git a/tests/common/configuration/test_container.py b/tests/common/configuration/test_container.py new file mode 100644 index 0000000000..7a73a73b46 --- /dev/null +++ b/tests/common/configuration/test_container.py @@ -0,0 +1,118 @@ +from typing import Any +import pytest + +from dlt.common.configuration import configspec +from dlt.common.configuration.providers.container import ContainerProvider +from dlt.common.configuration.resolve import make_configuration +from dlt.common.configuration.specs import BaseConfiguration +from dlt.common.configuration.container import Container, ContainerInjectableConfiguration +from dlt.common.configuration.exceptions import ContainerInjectableConfigurationMangled, InvalidInitialValue +from dlt.common.configuration.specs.config_providers_configuration import ConfigProvidersListConfiguration + +from tests.utils import preserve_environ +from tests.common.configuration.utils import environment + + +@configspec(init=True) +class InjectableTestConfiguration(ContainerInjectableConfiguration): + current_value: str + + +@configspec +class EmbeddedWithInjectableConfiguration(BaseConfiguration): + injected: InjectableTestConfiguration + + +@pytest.fixture() +def container() -> Container: + # erase singleton + Container._INSTANCE = None + return Container() + + +def test_singleton(container: Container) -> None: + # keep the old configurations list + container_configurations = container.configurations + + singleton = Container() + # make sure it is the same object + assert container is singleton + # that holds the same configurations dictionary + 
assert container_configurations is singleton.configurations + + +def test_get_default_injectable_config(container: Container) -> None: + pass + + +def test_container_injectable_context(container: Container) -> None: + with container.injectable_configuration(InjectableTestConfiguration()) as current_config: + assert current_config.current_value is None + current_config.current_value = "TEST" + assert container[InjectableTestConfiguration].current_value == "TEST" + assert container[InjectableTestConfiguration] is current_config + + assert InjectableTestConfiguration not in container + + +def test_container_injectable_context_restore(container: Container) -> None: + # this will create InjectableTestConfiguration + original = container[InjectableTestConfiguration] + original.current_value = "ORIGINAL" + with container.injectable_configuration(InjectableTestConfiguration()) as current_config: + current_config.current_value = "TEST" + # nested context is supported + with container.injectable_configuration(InjectableTestConfiguration()) as inner_config: + assert inner_config.current_value is None + assert container[InjectableTestConfiguration] is inner_config + assert container[InjectableTestConfiguration] is current_config + + assert container[InjectableTestConfiguration] is original + assert container[InjectableTestConfiguration].current_value == "ORIGINAL" + + +def test_container_injectable_context_mangled(container: Container) -> None: + original = container[InjectableTestConfiguration] + original.current_value = "ORIGINAL" + + injectable = InjectableTestConfiguration() + with pytest.raises(ContainerInjectableConfigurationMangled) as py_ex: + with container.injectable_configuration(injectable) as current_config: + current_config.current_value = "TEST" + # overwrite the config in container + container.configurations[InjectableTestConfiguration] = InjectableTestConfiguration() + assert py_ex.value.spec == InjectableTestConfiguration + assert py_ex.value.expected_config == injectable + + +def test_container_provider(container: Container) -> None: + provider = ContainerProvider() + v, k = provider.get_value("n/a", InjectableTestConfiguration) + # provider does not create default value in Container + assert v is None + assert k == str(InjectableTestConfiguration) + assert InjectableTestConfiguration not in container + + original = container[InjectableTestConfiguration] + original.current_value = "ORIGINAL" + v, _ = provider.get_value("n/a", InjectableTestConfiguration) + assert v is original + + # must assert if namespaces are provided + with pytest.raises(AssertionError): + provider.get_value("n/a", InjectableTestConfiguration, ("ns1",)) + + +def test_container_provider_embedded_inject(container: Container, environment: Any) -> None: + environment["INJECTED"] = "unparsable" + with container.injectable_configuration(InjectableTestConfiguration(current_value="Embed")) as injected: + # must have top precedence - over the environ provider. 
environ provider is returning a value that will cannot be parsed + # but the container provider has a precedence and the lookup in environ provider will never happen + C = make_configuration(EmbeddedWithInjectableConfiguration()) + assert C.injected.current_value == "Embed" + assert C.injected is injected + # remove first provider + container[ConfigProvidersListConfiguration].providers.pop(0) + # now environment will provide unparsable value + with pytest.raises(InvalidInitialValue): + C = make_configuration(EmbeddedWithInjectableConfiguration()) From 81eb730215e95186e160f711f7c1def7a24e6554 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 12 Oct 2022 22:47:16 +0200 Subject: [PATCH 29/66] adds config providers framewrok, injectable config and basic tests --- .../configuration/providers/configuration.py | 17 --- dlt/common/configuration/providers/environ.py | 8 +- .../configuration/providers/provider.py | 16 ++- .../specs/config_providers_configuration.py | 25 +++++ tests/common/configuration/__init__.py | 0 .../configuration/test_environ_provider.py | 106 ++++++++++++++++++ tests/common/configuration/test_providers.py | 17 +++ tests/common/configuration/utils.py | 95 ++++++++++++++++ 8 files changed, 262 insertions(+), 22 deletions(-) delete mode 100644 dlt/common/configuration/providers/configuration.py create mode 100644 dlt/common/configuration/specs/config_providers_configuration.py create mode 100644 tests/common/configuration/__init__.py create mode 100644 tests/common/configuration/test_environ_provider.py create mode 100644 tests/common/configuration/test_providers.py create mode 100644 tests/common/configuration/utils.py diff --git a/dlt/common/configuration/providers/configuration.py b/dlt/common/configuration/providers/configuration.py deleted file mode 100644 index cdfbaaac1f..0000000000 --- a/dlt/common/configuration/providers/configuration.py +++ /dev/null @@ -1,17 +0,0 @@ - - -from typing import List - -from dlt.common.configuration.providers import Provider -from dlt.common.configuration.providers.environ import EnvironProvider -from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec - - -@configspec -class ConfigProvidersConfiguration(BaseConfiguration): - providers: List[Provider] - - def __init__(self) -> None: - super().__init__() - # add default providers - self.providers = [EnvironProvider()] diff --git a/dlt/common/configuration/providers/environ.py b/dlt/common/configuration/providers/environ.py index 02278cb057..2ea3df7a96 100644 --- a/dlt/common/configuration/providers/environ.py +++ b/dlt/common/configuration/providers/environ.py @@ -14,7 +14,7 @@ class EnvironProvider(Provider): def get_key_name(key: str, *namespaces: str) -> str: # env key is always upper case if namespaces: - namespaces = filter(lambda x: bool(x), namespaces) + namespaces = filter(lambda x: bool(x), namespaces) # type: ignore env_key = "__".join((*namespaces, key)) else: env_key = key @@ -50,5 +50,9 @@ def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Option return environ.get(key, None), key @property - def is_secret(self) -> bool: + def supports_secrets(self) -> bool: + return True + + @property + def supports_namespaces(self) -> bool: return True diff --git a/dlt/common/configuration/providers/provider.py b/dlt/common/configuration/providers/provider.py index 5257221560..0ecd69833c 100644 --- a/dlt/common/configuration/providers/provider.py +++ b/dlt/common/configuration/providers/provider.py @@ -4,8 +4,8 @@ class 
Provider(abc.ABC): - def __init__(self) -> None: - pass + # def __init__(self) -> None: + # pass @abc.abstractmethod def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Optional[Any], str]: @@ -13,10 +13,20 @@ def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Option @property @abc.abstractmethod - def is_secret(self) -> bool: + def supports_secrets(self) -> bool: + pass + + @property + @abc.abstractmethod + def supports_namespaces(self) -> bool: pass @property @abc.abstractmethod def name(self) -> str: pass + + +def detect_known_providers() -> None: + # detects providers flagged + pass \ No newline at end of file diff --git a/dlt/common/configuration/specs/config_providers_configuration.py b/dlt/common/configuration/specs/config_providers_configuration.py new file mode 100644 index 0000000000..a10c258a19 --- /dev/null +++ b/dlt/common/configuration/specs/config_providers_configuration.py @@ -0,0 +1,25 @@ + + +from typing import List + +from dlt.common.configuration.providers import Provider +from dlt.common.configuration.container import ContainerInjectableConfiguration +from dlt.common.configuration.providers.environ import EnvironProvider +from dlt.common.configuration.providers.container import ContainerProvider +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec + + +@configspec +class ConfigProvidersListConfiguration(ContainerInjectableConfiguration): + providers: List[Provider] + + def __init__(self) -> None: + super().__init__() + # add default providers, ContainerProvider must be always first - it will provide injectable configs + self.providers = [ContainerProvider(), EnvironProvider()] + + +@configspec +class ConfigProvidersConfiguration(BaseConfiguration): + with_aws_secrets: bool = False + with_google_secrets: bool = False diff --git a/tests/common/configuration/__init__.py b/tests/common/configuration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/common/configuration/test_environ_provider.py b/tests/common/configuration/test_environ_provider.py new file mode 100644 index 0000000000..36285b0c41 --- /dev/null +++ b/tests/common/configuration/test_environ_provider.py @@ -0,0 +1,106 @@ +import pytest +from typing import Any + +from dlt.common.typing import TSecretValue +from dlt.common.configuration import configspec, ConfigEntryMissingException, ConfigFileNotFoundException, resolve +from dlt.common.configuration.specs import RunConfiguration +from dlt.common.configuration.providers import environ as environ_provider + +from tests.utils import preserve_environ +from tests.common.configuration.utils import WrongConfiguration, SecretConfiguration, environment + + +@configspec +class SimpleConfiguration(RunConfiguration): + pipeline_name: str = "Some Name" + test_bool: bool = False + + +@configspec +class SecretKubeConfiguration(RunConfiguration): + pipeline_name: str = "secret kube" + secret_kube: TSecretValue = None + + +@configspec +class MockProdConfigurationVar(RunConfiguration): + pipeline_name: str = "comp" + + + +def test_resolves_from_environ(environment: Any) -> None: + environment["NONECONFIGVAR"] = "Some" + + C = WrongConfiguration() + resolve._resolve_config_fields(C, namespaces=(), accept_partial=False) + assert not C.is_partial() + + assert C.NoneConfigVar == environment["NONECONFIGVAR"] + + +def test_resolves_from_environ_with_coercion(environment: Any) -> None: + environment["TEST_BOOL"] = 'yes' + + C = SimpleConfiguration() + 
resolve._resolve_config_fields(C, namespaces=(), accept_partial=False) + assert not C.is_partial() + + # value will be coerced to bool + assert C.test_bool is True + + +def test_secret(environment: Any) -> None: + with pytest.raises(ConfigEntryMissingException): + resolve.make_configuration(SecretConfiguration()) + environment['SECRET_VALUE'] = "1" + C = resolve.make_configuration(SecretConfiguration()) + assert C.secret_value == "1" + # mock the path to point to secret storage + # from dlt.common.configuration import config_utils + path = environ_provider.SECRET_STORAGE_PATH + del environment['SECRET_VALUE'] + try: + # must read a secret file + environ_provider.SECRET_STORAGE_PATH = "./tests/common/cases/%s" + C = resolve.make_configuration(SecretConfiguration()) + assert C.secret_value == "BANANA" + + # set some weird path, no secret file at all + del environment['SECRET_VALUE'] + environ_provider.SECRET_STORAGE_PATH = "!C:\\PATH%s" + with pytest.raises(ConfigEntryMissingException): + resolve.make_configuration(SecretConfiguration()) + + # set env which is a fallback for secret not as file + environment['SECRET_VALUE'] = "1" + C = resolve.make_configuration(SecretConfiguration()) + assert C.secret_value == "1" + finally: + environ_provider.SECRET_STORAGE_PATH = path + + +def test_secret_kube_fallback(environment: Any) -> None: + path = environ_provider.SECRET_STORAGE_PATH + try: + environ_provider.SECRET_STORAGE_PATH = "./tests/common/cases/%s" + C = resolve.make_configuration(SecretKubeConfiguration()) + # all unix editors will add x10 at the end of file, it will be preserved + assert C.secret_kube == "kube\n" + # we propagate secrets back to environ and strip the whitespace + assert environment['SECRET_KUBE'] == "kube" + finally: + environ_provider.SECRET_STORAGE_PATH = path + + +def test_configuration_files(environment: Any) -> None: + # overwrite config file paths + environment["CONFIG_FILES_STORAGE_PATH"] = "./tests/common/cases/schemas/ev1/%s" + C = resolve.make_configuration(MockProdConfigurationVar()) + assert C.config_files_storage_path == environment["CONFIG_FILES_STORAGE_PATH"] + assert C.has_configuration_file("hasn't") is False + assert C.has_configuration_file("event_schema.json") is True + assert C.get_configuration_file_path("event_schema.json") == "./tests/common/cases/schemas/ev1/event_schema.json" + with C.open_configuration_file("event_schema.json", "r") as f: + f.read() + with pytest.raises(ConfigFileNotFoundException): + C.open_configuration_file("hasn't", "r") diff --git a/tests/common/configuration/test_providers.py b/tests/common/configuration/test_providers.py new file mode 100644 index 0000000000..2e88a7af58 --- /dev/null +++ b/tests/common/configuration/test_providers.py @@ -0,0 +1,17 @@ +def test_providers_order() -> None: + pass + + +def test_add_remove_providers() -> None: + # TODO: we should be able to add and remove providers + pass + + +def test_providers_autodetect_and_config() -> None: + # TODO: toml based and remote vaults should be configured and/or autodetected + pass + + +def test_providers_value_getter() -> None: + # TODO: it should be possible to get a value from providers' chain via `config` and `secrets` objects via indexer (nested) or explicit key, *namespaces getter + pass \ No newline at end of file diff --git a/tests/common/configuration/utils.py b/tests/common/configuration/utils.py new file mode 100644 index 0000000000..430b971491 --- /dev/null +++ b/tests/common/configuration/utils.py @@ -0,0 +1,95 @@ +import pytest +from os import 
environ +from typing import Any, List, Optional, Tuple, Type +from dlt.common.configuration.container import Container +from dlt.common.configuration.specs.config_providers_configuration import ConfigProvidersListConfiguration + +from dlt.common.typing import TSecretValue +from dlt.common.configuration import configspec +from dlt.common.configuration.providers import Provider +from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration, RunConfiguration + + +@configspec +class WrongConfiguration(RunConfiguration): + pipeline_name: str = "Some Name" + NoneConfigVar: str = None + log_color: bool = True + + +@configspec +class SecretConfiguration(BaseConfiguration): + secret_value: TSecretValue = None + + +@configspec +class SecretCredentials(CredentialsConfiguration): + secret_value: TSecretValue = None + + +@configspec +class WithCredentialsConfiguration(BaseConfiguration): + credentials: SecretCredentials + + +@configspec +class NamespacedConfiguration(BaseConfiguration): + __namespace__ = "DLT_TEST" + + password: str = None + + +@pytest.fixture(scope="function") +def environment() -> Any: + environ.clear() + return environ + + +@pytest.fixture(scope="function") +def mock_provider() -> "MockProvider": + container = Container() + with container.injectable_configuration(ConfigProvidersListConfiguration()) as providers: + # replace all providers with MockProvider that does not support secrets + mock_provider = MockProvider() + providers.providers = [mock_provider] + yield mock_provider + + +class MockProvider(Provider): + + def __init__(self) -> None: + self.value: Any = None + self.return_value_on: Tuple[str] = () + self.reset_stats() + + def reset_stats(self) -> None: + self.last_namespace: Tuple[str] = None + self.last_namespaces: List[Tuple[str]] = [] + + def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Optional[Any], str]: + self.last_namespace = namespaces + self.last_namespaces.append(namespaces) + print("|".join(namespaces) + "-" + key) + if namespaces == self.return_value_on: + rv = self.value + else: + rv = None + return rv, "|".join(namespaces) + "-" + key + + @property + def supports_secrets(self) -> bool: + return False + + @property + def supports_namespaces(self) -> bool: + return True + + @property + def name(self) -> str: + return "Mock Provider" + + +class SecretMockProvider(MockProvider): + @property + def supports_secrets(self) -> bool: + return True From f08b8d5b81fdda10baee3407b12d7554963ee804 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 12 Oct 2022 22:48:59 +0200 Subject: [PATCH 30/66] adds namespaces, pipeline name, injectable namespace config to config resolver, typing improvements and tests --- dlt/common/configuration/exceptions.py | 23 +- dlt/common/configuration/resolve.py | 126 +++++---- .../configuration/specs/base_configuration.py | 32 ++- .../specs/normalize_volume_configuration.py | 2 +- dlt/common/logger.py | 2 +- dlt/common/typing.py | 19 +- dlt/dbt_runner/configuration.py | 8 +- .../configuration/test_configuration.py | 242 ++++++------------ tests/common/configuration/test_namespaces.py | 179 +++++++++++++ tests/common/test_typing.py | 2 +- tests/conftest.py | 22 +- 11 files changed, 413 insertions(+), 244 deletions(-) create mode 100644 tests/common/configuration/test_namespaces.py diff --git a/dlt/common/configuration/exceptions.py b/dlt/common/configuration/exceptions.py index ce13c949bb..f64ce00c39 100644 --- a/dlt/common/configuration/exceptions.py +++ 
b/dlt/common/configuration/exceptions.py @@ -1,4 +1,4 @@ -from typing import Any, Iterable, Mapping, Type, Union, NamedTuple, Sequence +from typing import Any, Mapping, Type, Union, NamedTuple, Sequence from dlt.common.exceptions import DltException @@ -28,9 +28,9 @@ def __init__(self, spec_name: str, traces: Mapping[str, Sequence[LookupTrace]]) self.spec_name = spec_name msg = f"Following fields are missing: {str(list(traces.keys()))} in configuration with spec {spec_name}\n" - for f, traces in traces.items(): - msg += f'\tfor field "{f}" config providers and keys were tried in following order\n' - for tr in traces: + for f, field_traces in traces.items(): + msg += f'\tfor field "{f}" config providers and keys were tried in following order:\n' + for tr in field_traces: msg += f'\t\tIn {tr.provider} key {tr.key} was not found.\n' super().__init__(msg) @@ -85,3 +85,18 @@ def __init__(self, provider_name: str, key: str) -> None: self.provider_name = provider_name self.key = key super().__init__(f"Provider {provider_name} cannot hold secret values but key {key} with secret value is present") + + +class InvalidInitialValue(ConfigurationException): + def __init__(self, spec: Type[Any], initial_value_type: Type[Any]) -> None: + self.spec = spec + self.initial_value_type = initial_value_type + super().__init__(f"Initial value of type {initial_value_type} is not valid for {spec.__name__}") + + +class ContainerInjectableConfigurationMangled(ConfigurationException): + def __init__(self, spec: Type[Any], existing_config: Any, expected_config: Any) -> None: + self.spec = spec + self.existing_config = existing_config + self.expected_config = expected_config + super().__init__(f"When restoring injectable config {spec.__name__}, instance {expected_config} was expected, instead instance {existing_config} was found.") diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index 510d7f4e24..e0347d460f 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -1,25 +1,26 @@ import ast -import dataclasses import inspect import sys import semver +import dataclasses from collections.abc import Mapping as C_Mapping -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, get_origin +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, get_origin from dlt.common import json, logger from dlt.common.typing import TSecretValue, is_optional_type, extract_inner_type from dlt.common.schema.utils import coerce_type, py_type_to_sc_type from dlt.common.configuration.specs.base_configuration import BaseConfiguration, CredentialsConfiguration, configspec -from dlt.common.configuration.inject import Container -from dlt.common.configuration.providers.configuration import ConfigProvidersConfiguration -from dlt.common.configuration.exceptions import (LookupTrace, ConfigEntryMissingException, ConfigurationWrongTypeException, ConfigEnvValueCannotBeCoercedException, ValueNotSecretException) +from dlt.common.configuration.container import Container, ContainerInjectableConfiguration +from dlt.common.configuration.specs.config_providers_configuration import ConfigProvidersListConfiguration +from dlt.common.configuration.providers.container import ContainerProvider +from dlt.common.configuration.exceptions import (LookupTrace, ConfigEntryMissingException, ConfigurationWrongTypeException, ConfigEnvValueCannotBeCoercedException, ValueNotSecretException, InvalidInitialValue) CHECK_INTEGRITY_F: str = 
"check_integrity" TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration) -def make_configuration(config: TConfiguration, initial_value: Any = None, accept_partial: bool = False) -> TConfiguration: +def make_configuration(config: TConfiguration, *, namespaces: Tuple[str, ...] = (), initial_value: Any = None, accept_partial: bool = False) -> TConfiguration: if not isinstance(config, BaseConfiguration): raise ConfigurationWrongTypeException(type(config)) @@ -36,7 +37,7 @@ def make_configuration(config: TConfiguration, initial_value: Any = None, accept raise InvalidInitialValue(type(config), type(initial_value)) try: - _resolve_config_fields(config, accept_partial) + _resolve_config_fields(config, namespaces, accept_partial) _check_configuration_integrity(config) # full configuration was resolved config.__is_resolved__ = True @@ -96,7 +97,7 @@ def _add_module_version(config: BaseConfiguration) -> None: pass -def _resolve_config_fields(config: BaseConfiguration, accept_partial: bool) -> None: +def _resolve_config_fields(config: BaseConfiguration, namespaces: Tuple[str, ...], accept_partial: bool) -> None: fields = config.get_resolvable_fields() unresolved_fields: Dict[str, Sequence[LookupTrace]] = {} @@ -109,15 +110,18 @@ def _resolve_config_fields(config: BaseConfiguration, accept_partial: bool) -> N accept_partial = accept_partial or is_optional # if actual value is BaseConfiguration, resolve that instance if isinstance(current_value, BaseConfiguration): - current_value = make_configuration(current_value, accept_partial=accept_partial) + # add key as innermost namespace + current_value = make_configuration(current_value, namespaces=namespaces + (key,), accept_partial=accept_partial) else: # resolve key value via active providers - value, traces = _resolve_single_field(key, hint, config.__namespace__) + value, traces = _resolve_single_field(key, hint, config.__namespace__, *namespaces) # log trace if logger.is_logging() and logger.log_level() == "DEBUG": logger.debug(f"Field {key} with type {hint} in {type(config).__name__} {'NOT RESOLVED' if value is None else 'RESOLVED'}") + # print(f"Field {key} with type {hint} in {type(config).__name__} {'NOT RESOLVED' if value is None else 'RESOLVED'}") for tr in traces: + # print(str(tr)) logger.debug(str(tr)) # extract hint from Optional / Literal / NewType hints @@ -126,8 +130,12 @@ def _resolve_config_fields(config: BaseConfiguration, accept_partial: bool) -> N hint = get_origin(hint) or hint # if hint is BaseConfiguration then resolve it recursively if inspect.isclass(hint) and issubclass(hint, BaseConfiguration): - # create new instance and pass value from the provider as initial - current_value = make_configuration(hint(), initial_value=value or current_value, accept_partial=accept_partial) + if isinstance(value, BaseConfiguration): + # if value is base configuration already (ie. 
via ContainerProvider) return it directly + current_value = value + else: + # create new instance and pass value from the provider as initial, add key to namespaces + current_value = make_configuration(hint(), namespaces=namespaces + (key,), initial_value=value or current_value, accept_partial=accept_partial) else: if value is not None: current_value = deserialize_value(key, value, hint) @@ -140,17 +148,6 @@ def _resolve_config_fields(config: BaseConfiguration, accept_partial: bool) -> N raise ConfigEntryMissingException(type(config).__name__, unresolved_fields) -# def _is_config_bounded(config: BaseConfiguration, fields: Mapping[str, type]) -> None: -# # TODO: here we assume all keys are taken from environ provider, that should change when we introduce more providers -# # environ.get_key_name(key, config.__namespace__) -# _unbound_attrs = [ -# key for key in fields if getattr(config, key) is None and not is_optional_type(fields[key]) -# ] - -# if len(_unbound_attrs) > 0: -# raise ConfigEntryMissingException(_unbound_attrs, config.__namespace__) - - def _check_configuration_integrity(config: BaseConfiguration) -> None: # python multi-inheritance is cooperative and this would require that all configurations cooperatively # call each other check_integrity. this is not at all possible as we do not know which configs in the end will @@ -165,42 +162,69 @@ def _check_configuration_integrity(config: BaseConfiguration) -> None: c.__dict__[CHECK_INTEGRITY_F](config) -def _get_resolvable_fields(config: BaseConfiguration) -> Dict[str, type]: - return {f.name:f.type for f in dataclasses.fields(config) if not f.name.startswith("__")} - - -@configspec -class ConfigNamespacesConfiguration(BaseConfiguration): +@configspec(init=True) +class ConfigNamespacesConfiguration(ContainerInjectableConfiguration): pipeline_name: Optional[str] - namespaces: List[str] + namespaces: List[str] = dataclasses.field(default_factory=lambda: []) - def __init__(self) -> None: - super().__init__() - self.namespaces = [] - -def _resolve_single_field(key: str, hint: Type[Any], namespace: str, *namespaces: str) -> Tuple[Optional[Any], List[LookupTrace]]: +def _resolve_single_field(key: str, hint: Type[Any], config_namespace: str, *namespaces: str) -> Tuple[Optional[Any], List[LookupTrace]]: + container = Container() # get providers from container - providers = Container()[ConfigProvidersConfiguration].providers + providers = container[ConfigProvidersListConfiguration].providers # get additional namespaces to look in from container - context_namespaces = Container()[ConfigNamespacesConfiguration].namespaces + ctx_namespaces = container[ConfigNamespacesConfiguration] + # pipeline_name = ctx_namespaces.pipeline_name # start looking from the top provider with most specific set of namespaces first traces: List[LookupTrace] = [] value = None - ns = [*namespaces, *context_namespaces] - for provider in providers: - while True: - # first namespace always present - _ns_t = (namespace, *ns) if namespace else ns - value, ns_key = provider.get_value(key, hint, *_ns_t) - # create trace - traces.append(LookupTrace(provider.name, _ns_t, ns_key, value)) - # if secret is obtained from non secret provider, we must fail - if value is not None and not provider.is_secret and (hint is TSecretValue or (inspect.isclass(hint) and issubclass(hint, CredentialsConfiguration))): - raise ValueNotSecretException(provider.name, ns_key) - if len(ns) == 0 or value is not None: - break - ns.pop() + + def look_namespaces(pipeline_name: str = None) -> Any: + for 
provider in providers: + if provider.supports_namespaces: + ns = [*ctx_namespaces.namespaces, *namespaces] + else: + # if provider does not support namespaces and pipeline name is set then ignore it + if pipeline_name: + continue + else: + # pass empty namespaces + ns = [] + + value = None + while True: + if pipeline_name or config_namespace: + full_ns = ns.copy() + # pipeline, when provided, is the most outer and always present + if pipeline_name: + full_ns.insert(0, pipeline_name) + # config namespace, when provided, is innermost and always present + if config_namespace and provider.supports_namespaces: + full_ns.append(config_namespace) + else: + full_ns = ns + value, ns_key = provider.get_value(key, hint, *full_ns) + # create trace, ignore container provider + if provider.name != ContainerProvider.NAME: + traces.append(LookupTrace(provider.name, full_ns, ns_key, value)) + # if secret is obtained from non secret provider, we must fail + if value is not None and not provider.supports_secrets and (hint is TSecretValue or (inspect.isclass(hint) and issubclass(hint, CredentialsConfiguration))): + raise ValueNotSecretException(provider.name, ns_key) + if value is not None: + # value found, ignore other providers + return value + if len(ns) == 0: + # check next provider + break + # pop optional namespaces for less precise lookup + ns.pop() + + # first try with pipeline name as namespace, if present + if ctx_namespaces.pipeline_name: + value = look_namespaces(ctx_namespaces.pipeline_name) + # then without it + if value is None: + value = look_namespaces() return value, traces diff --git a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py index 57ed3c78c5..9895052e08 100644 --- a/dlt/common/configuration/specs/base_configuration.py +++ b/dlt/common/configuration/specs/base_configuration.py @@ -1,13 +1,14 @@ import contextlib import dataclasses -from typing import Any, Dict, Iterator, MutableMapping, Type, TYPE_CHECKING, get_origin + +from typing import Callable, Optional, Union, Any, Dict, Iterator, MutableMapping, Type, TYPE_CHECKING, get_origin, overload if TYPE_CHECKING: TDtcField = dataclasses.Field[Any] else: TDtcField = dataclasses.Field -from dlt.common.typing import TAny, extract_inner_type, is_optional_type +from dlt.common.typing import TAnyClass, extract_inner_type, is_optional_type from dlt.common.schema.utils import py_type_to_sc_type from dlt.common.configuration.exceptions import ConfigFieldMissingTypeHintException, ConfigFieldTypeHintNotSupported @@ -25,9 +26,25 @@ def is_valid_hint(hint: Type[Any]) -> bool: return False -def configspec(cls: Type[TAny] = None, /, *, init: bool = False) -> Type[TAny]: +@overload +def configspec(cls: Type[TAnyClass], /, *, init: bool = False) -> Type[TAnyClass]: + ... + + +@overload +def configspec(cls: None = ..., /, *, init: bool = False) -> Callable[[Type[TAnyClass]], Type[TAnyClass]]: + ... 
- def wrap(cls: Type[TAny]) -> Type[TAny]: + +def configspec(cls: Optional[Type[Any]] = None, /, *, init: bool = False) -> Union[Type[TAnyClass], Callable[[Type[TAnyClass]], Type[TAnyClass]]]: + + def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]: + # if type does not derive from BaseConfiguration then derive it + with contextlib.suppress(NameError): + if not issubclass(cls, BaseConfiguration): + # keep the original module + fields = {"__module__": cls.__module__, "__annotations__": getattr(cls, "__annotations__", {})} + cls = type(cls.__name__, (cls, BaseConfiguration), fields) # get all annotations without corresponding attributes and set them to None for ann in cls.__annotations__: if not hasattr(cls, ann) and not ann.startswith(("__", "_abc_impl")): @@ -40,11 +57,12 @@ def wrap(cls: Type[TAny]) -> Type[TAny]: hint = cls.__annotations__[att_name] if not is_valid_hint(hint): raise ConfigFieldTypeHintNotSupported(att_name, cls, hint) - return dataclasses.dataclass(cls, init=init, eq=False) # type: ignore + # do not generate repr as it may contain secret values + return dataclasses.dataclass(cls, init=init, eq=False, repr=False) # type: ignore # called with parenthesis if cls is None: - return wrap # type: ignore + return wrap return wrap(cls) @@ -142,4 +160,6 @@ def __fields_dict(self) -> Dict[str, TDtcField]: @configspec class CredentialsConfiguration(BaseConfiguration): + """Base class for all credentials. Credentials are configurations that may be stored only by providers supporting secrets.""" pass + diff --git a/dlt/common/configuration/specs/normalize_volume_configuration.py b/dlt/common/configuration/specs/normalize_volume_configuration.py index 584f271169..e1f2946947 100644 --- a/dlt/common/configuration/specs/normalize_volume_configuration.py +++ b/dlt/common/configuration/specs/normalize_volume_configuration.py @@ -1,6 +1,6 @@ from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec -@configspec +@configspec(init=True) class NormalizeVolumeConfiguration(BaseConfiguration): normalize_volume_path: str = None # path to volume where normalized loader files will be stored diff --git a/dlt/common/logger.py b/dlt/common/logger.py index 2d55636f43..eecf86ede6 100644 --- a/dlt/common/logger.py +++ b/dlt/common/logger.py @@ -218,7 +218,7 @@ def is_logging() -> bool: def log_level() -> str: if not LOGGER: raise RuntimeError("Logger not initialized") - return logging.getLevelName(LOGGER.level) + return logging.getLevelName(LOGGER.level) # type: ignore def is_json_logging(log_format: str) -> bool: diff --git a/dlt/common/typing.py b/dlt/common/typing.py index aef5b9f245..71058093d5 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -1,6 +1,10 @@ from collections.abc import Mapping as C_Mapping, Sequence as C_Sequence from re import Pattern as _REPattern from typing import Callable, Dict, Any, Literal, Mapping, NewType, Tuple, Type, TypeVar, Generic, Protocol, Iterable, TYPE_CHECKING, Union, runtime_checkable, get_args, get_origin +try: + from typing_extensions import ParamSpec, TypeAlias, TypeGuard +except ImportError: + ParamSpec = lambda x: [x] # type: ignore if TYPE_CHECKING: from _typeshed import StrOrBytesPath from typing import _TypedDict @@ -10,13 +14,15 @@ from typing import _TypedDictMeta as _TypedDict REPattern = _REPattern -DictStrAny = Dict[str, Any] -DictStrStr = Dict[str, str] -StrAny = Mapping[str, Any] # immutable, covariant entity -StrStr = Mapping[str, str] # immutable, covariant entity -StrStrStr = Mapping[str, Mapping[str, 
str]] # immutable, covariant entity -TFun = TypeVar("TFun", bound=Callable[..., Any]) +DictStrAny: TypeAlias = Dict[str, Any] +DictStrStr: TypeAlias = Dict[str, str] +StrAny: TypeAlias = Mapping[str, Any] # immutable, covariant entity +StrStr: TypeAlias = Mapping[str, str] # immutable, covariant entity +StrStrStr: TypeAlias = Mapping[str, Mapping[str, str]] # immutable, covariant entity +AnyFun: TypeAlias = Callable[..., Any] +TFun = TypeVar("TFun", bound=AnyFun) # any function TAny = TypeVar("TAny", bound=Any) +TAnyClass = TypeVar("TAnyClass", bound=object) TSecretValue = NewType("TSecretValue", str) # represent secret value ie. coming from Kubernetes/Docker secrets or other providers TDataItem = Any # a single data item extracted from data source, normalized and loaded @@ -89,4 +95,3 @@ def extract_inner_type(hint: Type[Any]) -> Type[Any]: # descend into supertypes of NewType return extract_inner_type(hint.__supertype__) return hint - diff --git a/dlt/dbt_runner/configuration.py b/dlt/dbt_runner/configuration.py index 3244535977..7516e6eff6 100644 --- a/dlt/dbt_runner/configuration.py +++ b/dlt/dbt_runner/configuration.py @@ -1,9 +1,10 @@ import dataclasses +from os import environ from typing import List, Optional, Type from dlt.common.typing import StrAny, TSecretValue from dlt.common.configuration import make_configuration, configspec -from dlt.common.configuration.providers import environ +from dlt.common.configuration.providers import EnvironProvider from dlt.common.configuration.specs import PoolRunnerConfiguration, TPoolType, PostgresCredentials, GcpClientCredentials from . import __version__ @@ -39,10 +40,11 @@ def check_integrity(self) -> None: def gen_configuration_variant(initial_values: StrAny = None) -> DBTRunnerConfiguration: # derive concrete config depending on env vars present DBTRunnerConfigurationImpl: Type[DBTRunnerConfiguration] + environ = EnvironProvider() - source_schema_prefix = environ.get_key("default_dataset", type(str)) + source_schema_prefix: str = environ.get_value("default_dataset", type(str)) # type: ignore - if environ.get_key("project_id", type(str), namespace=GcpClientCredentials.__namespace__): + if environ.get_value("project_id", type(str), GcpClientCredentials.__namespace__): @configspec class DBTRunnerConfigurationPostgres(PostgresCredentials, DBTRunnerConfiguration): SOURCE_SCHEMA_PREFIX: str = source_schema_prefix diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index c19f079ea8..1fef49012a 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -1,17 +1,18 @@ import pytest -from os import environ import datetime # noqa: I251 from typing import Any, Dict, List, Mapping, MutableMapping, NewType, Optional, Sequence, Tuple, Type from dlt.common import pendulum, Decimal, Wei -from dlt.common.configuration.exceptions import ConfigFieldMissingTypeHintException, ConfigFieldTypeHintNotSupported, LookupTrace +from dlt.common.utils import custom_environ from dlt.common.typing import StrAny, TSecretValue, extract_inner_type -from dlt.common.configuration import configspec, ConfigEntryMissingException, ConfigFileNotFoundException, ConfigEnvValueCannotBeCoercedException, resolve +from dlt.common.configuration.exceptions import ConfigFieldMissingTypeHintException, ConfigFieldTypeHintNotSupported, InvalidInitialValue, LookupTrace, ValueNotSecretException +from dlt.common.configuration import configspec, ConfigEntryMissingException, 
ConfigEnvValueCannotBeCoercedException, resolve from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration +from dlt.common.configuration.specs.base_configuration import is_valid_hint from dlt.common.configuration.providers import environ as environ_provider -from dlt.common.utils import custom_environ from tests.utils import preserve_environ, add_config_dict_to_env +from tests.common.configuration.utils import MockProvider, WithCredentialsConfiguration, WrongConfiguration, SecretConfiguration, NamespacedConfiguration, environment, mock_provider # used to test version __version__ = "1.0.5" @@ -70,31 +71,6 @@ } -@configspec -class SimpleConfiguration(RunConfiguration): - pipeline_name: str = "Some Name" - test_bool: bool = False - - -@configspec -class WrongConfiguration(RunConfiguration): - pipeline_name: str = "Some Name" - NoneConfigVar: str = None - log_color: bool = True - - -@configspec -class SecretConfiguration(RunConfiguration): - pipeline_name: str = "secret" - secret_value: TSecretValue = None - - -@configspec -class SecretKubeConfiguration(RunConfiguration): - pipeline_name: str = "secret kube" - secret_kube: TSecretValue = None - - @configspec class CoercionTestConfiguration(RunConfiguration): pipeline_name: str = "Some Name" @@ -117,7 +93,6 @@ class CoercionTestConfiguration(RunConfiguration): mutable_mapping_val: MutableMapping[str, str] = None - @configspec class VeryWrongConfiguration(WrongConfiguration): pipeline_name: str = "Some Name" @@ -145,23 +120,10 @@ class MockProdConfiguration(RunConfiguration): pipeline_name: str = "comp" -@configspec -class MockProdConfigurationVar(RunConfiguration): - pipeline_name: str = "comp" - - -@configspec -class NamespacedConfiguration(BaseConfiguration): - __namespace__ = "DLT_TEST" - - password: str = None - - @configspec(init=True) class FieldWithNoDefaultConfiguration(RunConfiguration): no_default: str - @configspec(init=True) class InstrumentedConfiguration(BaseConfiguration): head: str @@ -201,12 +163,6 @@ class EmbeddedOptionalConfiguration(BaseConfiguration): SecondOrderStr = NewType("SecondOrderStr", FirstOrderStr) -@pytest.fixture(scope="function") -def environment() -> Any: - environ.clear() - return environ - - def test_initial_config_state() -> None: assert BaseConfiguration.__is_resolved__ is False assert BaseConfiguration.__namespace__ is None @@ -231,6 +187,14 @@ def test_set_initial_config_value(environment: Any) -> None: assert C.to_native_representation() == "h>tu>be>xhe" +def test_invalid_initial_config_value() -> None: + # 2137 cannot be parsed and also is not a dict that can initialize the fields + with pytest.raises(InvalidInitialValue) as py_ex: + resolve.make_configuration(InstrumentedConfiguration(), initial_value=2137) + assert py_ex.value.spec is InstrumentedConfiguration + assert py_ex.value.initial_value_type is int + + def test_check_integrity(environment: Any) -> None: with pytest.raises(RuntimeError): # head over hells @@ -294,6 +258,14 @@ def test_run_configuration_gen_name(environment: Any) -> None: def test_configuration_is_mutable_mapping(environment: Any) -> None: + + + @configspec + class _SecretCredentials(RunConfiguration): + pipeline_name: Optional[str] = "secret" + secret_value: TSecretValue = None + + # configurations provide full MutableMapping support # here order of items in dict matters expected_dict = { @@ -304,12 +276,12 @@ def test_configuration_is_mutable_mapping(environment: Any) -> None: 'log_level': 'DEBUG', 'request_timeout': (15, 300), 
'config_files_storage_path': '_storage/config/%s', - 'secret_value': None + "secret_value": None } - assert dict(SecretConfiguration()) == expected_dict + assert dict(_SecretCredentials()) == expected_dict environment["SECRET_VALUE"] = "secret" - C = resolve.make_configuration(SecretConfiguration()) + C = resolve.make_configuration(_SecretCredentials()) expected_dict["secret_value"] = "secret" assert dict(C) == expected_dict @@ -424,27 +396,6 @@ def test_accepts_optional_missing_fields(environment: Any) -> None: assert C.instrumented.is_partial() -def test_resolves_from_environ(environment: Any) -> None: - environment["NONECONFIGVAR"] = "Some" - - C = WrongConfiguration() - resolve._resolve_config_fields(C, accept_partial=False) - assert not C.is_partial() - - assert C.NoneConfigVar == environment["NONECONFIGVAR"] - - -def test_resolves_from_environ_with_coercion(environment: Any) -> None: - environment["TEST_BOOL"] = 'yes' - - C = SimpleConfiguration() - resolve._resolve_config_fields(C, accept_partial=False) - assert not C.is_partial() - - # value will be coerced to bool - assert C.test_bool is True - - def test_find_all_keys() -> None: keys = VeryWrongConfiguration().get_resolvable_fields() # assert hints and types: LOG_COLOR had it hint overwritten in derived class @@ -455,7 +406,7 @@ def test_coercion_to_hint_types(environment: Any) -> None: add_config_dict_to_env(COERCIONS) C = CoercionTestConfiguration() - resolve._resolve_config_fields(C, accept_partial=False) + resolve._resolve_config_fields(C, namespaces=(), accept_partial=False) for key in COERCIONS: assert getattr(C, key) == COERCIONS[key] @@ -494,7 +445,7 @@ def test_invalid_coercions(environment: Any) -> None: add_config_dict_to_env(INVALID_COERCIONS) for key, value in INVALID_COERCIONS.items(): try: - resolve._resolve_config_fields(C, accept_partial=False) + resolve._resolve_config_fields(C, namespaces=(), accept_partial=False) except ConfigEnvValueCannotBeCoercedException as coerc_exc: # must fail exactly on expected value if coerc_exc.field_name != key: @@ -509,7 +460,7 @@ def test_excepted_coercions(environment: Any) -> None: C = CoercionTestConfiguration() add_config_dict_to_env(COERCIONS) add_config_dict_to_env(EXCEPTED_COERCIONS, overwrite_keys=True) - resolve._resolve_config_fields(C, accept_partial=False) + resolve._resolve_config_fields(C, namespaces=(), accept_partial=False) for key in EXCEPTED_COERCIONS: assert getattr(C, key) == COERCED_EXCEPTIONS[key] @@ -533,9 +484,6 @@ class NoHintConfiguration(BaseConfiguration): NoHintConfiguration() - - - def test_make_configuration(environment: Any) -> None: # fill up configuration environment["NONECONFIGVAR"] = "1" @@ -544,7 +492,7 @@ def test_make_configuration(environment: Any) -> None: assert C.NoneConfigVar == "1" -def test_auto_derivation(environment: Any) -> None: +def test_dataclass_instantiation(environment: Any) -> None: # make_configuration works on instances of dataclasses and types are not modified environment['SECRET_VALUE'] = "1" C = resolve.make_configuration(SecretConfiguration()) @@ -560,7 +508,7 @@ def test_initial_values(environment: Any) -> None: environment["CREATED_VAL"] = "12837" # set initial values and allow partial config C = resolve.make_configuration(CoercionTestConfiguration(), - {"pipeline_name": "initial name", "none_val": type(environment), "created_val": 878232, "bytes_val": b"str"}, + initial_value={"pipeline_name": "initial name", "none_val": type(environment), "created_val": 878232, "bytes_val": b"str"}, accept_partial=True ) # from env 
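As a rough illustration of the lookup order these namespace changes appear to implement (explicit namespaces probed from most to least specific, the config's own __namespace__ kept innermost, env keys upper-cased and joined with "__"), here is a small hypothetical helper that mirrors that order. It is only a reading aid for the resolver and the namespace tests in this patch, not part of dlt, and the helper name is invented.

from typing import Iterator, List, Optional


def candidate_env_keys(key: str, config_namespace: Optional[str] = None, *namespaces: str) -> Iterator[str]:
    # hypothetical helper: reproduces the probing order suggested by
    # _resolve_single_field and EnvironProvider.get_key_name in this patch
    ns: List[str] = list(namespaces)
    while True:
        # the config namespace, when present, is always innermost
        full_ns = ns + [config_namespace] if config_namespace else list(ns)
        # the environ provider joins namespaces and key with "__" and upper-cases the result
        yield "__".join((*full_ns, key)).upper()
        if not ns:
            break
        # drop the innermost optional namespace for a less precise lookup
        ns.pop()


# NamespacedConfiguration.password resolved with namespaces=("ns1", "ns2") would be probed as:
# NS1__NS2__DLT_TEST__PASSWORD, NS1__DLT_TEST__PASSWORD, DLT_TEST__PASSWORD
print(list(candidate_env_keys("password", "DLT_TEST", "ns1", "ns2")))
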
@@ -587,60 +535,17 @@ def test_finds_version(environment: Any) -> None: global __version__ v = __version__ - C = resolve.make_configuration(SimpleConfiguration()) + C = resolve.make_configuration(BaseConfiguration()) assert C._version == v try: del globals()["__version__"] - C = resolve.make_configuration(SimpleConfiguration()) + C = resolve.make_configuration(BaseConfiguration()) assert not hasattr(C, "_version") finally: __version__ = v -def test_secret(environment: Any) -> None: - with pytest.raises(ConfigEntryMissingException): - resolve.make_configuration(SecretConfiguration()) - environment['SECRET_VALUE'] = "1" - C = resolve.make_configuration(SecretConfiguration()) - assert C.secret_value == "1" - # mock the path to point to secret storage - # from dlt.common.configuration import config_utils - path = environ_provider.SECRET_STORAGE_PATH - del environment['SECRET_VALUE'] - try: - # must read a secret file - environ_provider.SECRET_STORAGE_PATH = "./tests/common/cases/%s" - C = resolve.make_configuration(SecretConfiguration()) - assert C.secret_value == "BANANA" - - # set some weird path, no secret file at all - del environment['SECRET_VALUE'] - environ_provider.SECRET_STORAGE_PATH = "!C:\\PATH%s" - with pytest.raises(ConfigEntryMissingException): - resolve.make_configuration(SecretConfiguration()) - - # set env which is a fallback for secret not as file - environment['SECRET_VALUE'] = "1" - C = resolve.make_configuration(SecretConfiguration()) - assert C.secret_value == "1" - finally: - environ_provider.SECRET_STORAGE_PATH = path - - -def test_secret_kube_fallback(environment: Any) -> None: - path = environ_provider.SECRET_STORAGE_PATH - try: - environ_provider.SECRET_STORAGE_PATH = "./tests/common/cases/%s" - C = resolve.make_configuration(SecretKubeConfiguration()) - # all unix editors will add x10 at the end of file, it will be preserved - assert C.secret_kube == "kube\n" - # we propagate secrets back to environ and strip the whitespace - assert environment['SECRET_KUBE'] == "kube" - finally: - environ_provider.SECRET_STORAGE_PATH = path - - -def test_coerce_values() -> None: +def test_coercion_rules() -> None: with pytest.raises(ConfigEnvValueCannotBeCoercedException): coerce_single_value("key", "some string", int) assert coerce_single_value("key", "some string", str) == "some string" @@ -664,42 +569,53 @@ def test_coerce_values() -> None: coerce_single_value("key", "some string", Optional[LongInteger]) # type: ignore -def test_configuration_files(environment: Any) -> None: - # overwrite config file paths - environment["CONFIG_FILES_STORAGE_PATH"] = "./tests/common/cases/schemas/ev1/%s" - C = resolve.make_configuration(MockProdConfigurationVar()) - assert C.config_files_storage_path == environment["CONFIG_FILES_STORAGE_PATH"] - assert C.has_configuration_file("hasn't") is False - assert C.has_configuration_file("event_schema.json") is True - assert C.get_configuration_file_path("event_schema.json") == "./tests/common/cases/schemas/ev1/event_schema.json" - with C.open_configuration_file("event_schema.json", "r") as f: - f.read() - with pytest.raises(ConfigFileNotFoundException): - C.open_configuration_file("hasn't", "r") - - -def test_namespaced_configuration(environment: Any) -> None: - with pytest.raises(ConfigEntryMissingException) as exc_val: - resolve.make_configuration(NamespacedConfiguration()) - assert list(exc_val.value.traces.keys()) == ["password"] - assert exc_val.value.spec_name == "NamespacedConfiguration" - # check trace - traces = exc_val.value.traces["password"] 
- # only one provider and namespace was tried - assert len(traces) == 1 - assert traces[0] == LookupTrace("Environment Variables", ["DLT_TEST"], "DLT_TEST__PASSWORD", None) - - # init vars work without namespace - C = resolve.make_configuration(NamespacedConfiguration(), initial_value={"password": "PASS"}) - assert C.password == "PASS" - - # env var must be prefixed - environment["PASSWORD"] = "PASS" - with pytest.raises(ConfigEntryMissingException) as exc_val: - resolve.make_configuration(NamespacedConfiguration()) - environment["DLT_TEST__PASSWORD"] = "PASS" - C = resolve.make_configuration(NamespacedConfiguration()) - assert C.password == "PASS" +def test_is_valid_hint() -> None: + assert is_valid_hint(Any) is True + assert is_valid_hint(Optional[Any]) is True + assert is_valid_hint(RunConfiguration) is True + assert is_valid_hint(Optional[RunConfiguration]) is True + assert is_valid_hint(TSecretValue) is True + assert is_valid_hint(Optional[TSecretValue]) is True + # in case of generics, origin will be used and args are not checked + assert is_valid_hint(MutableMapping[TSecretValue, Any]) is True + # this is valid (args not checked) + assert is_valid_hint(MutableMapping[TSecretValue, ConfigEnvValueCannotBeCoercedException]) is True + assert is_valid_hint(Wei) is True + # any class type, except deriving from BaseConfiguration is wrong type + assert is_valid_hint(ConfigEntryMissingException) is False + + +def test_configspec_auto_base_config_derivation() -> None: + + @configspec(init=True) + class AutoBaseDerivationConfiguration: + auto: str + + assert issubclass(AutoBaseDerivationConfiguration, BaseConfiguration) + assert hasattr(AutoBaseDerivationConfiguration, "auto") + + assert AutoBaseDerivationConfiguration().auto is None + assert AutoBaseDerivationConfiguration(auto="auto").auto == "auto" + assert AutoBaseDerivationConfiguration(auto="auto").get_resolvable_fields() == {"auto": str} + # we preserve original module + assert AutoBaseDerivationConfiguration.__module__ == __name__ + assert not hasattr(BaseConfiguration, "auto") + + +def test_secret_value_not_secret_provider(mock_provider: MockProvider) -> None: + mock_provider.value = "SECRET" + + # TSecretValue will fail + with pytest.raises(ValueNotSecretException) as py_ex: + resolve.make_configuration(SecretConfiguration(), namespaces=("mock",)) + assert py_ex.value.provider_name == "Mock Provider" + assert py_ex.value.key == "-secret_value" + + # anything derived from CredentialsConfiguration will fail + with pytest.raises(ValueNotSecretException) as py_ex: + resolve.make_configuration(WithCredentialsConfiguration(), namespaces=("mock",)) + assert py_ex.value.provider_name == "Mock Provider" + assert py_ex.value.key == "-credentials" def coerce_single_value(key: str, value: str, hint: Type[Any]) -> Any: diff --git a/tests/common/configuration/test_namespaces.py b/tests/common/configuration/test_namespaces.py new file mode 100644 index 0000000000..d79cf9a2f8 --- /dev/null +++ b/tests/common/configuration/test_namespaces.py @@ -0,0 +1,179 @@ +from unittest import mock +import pytest +from typing import Any, Optional +from dlt.common.configuration.container import Container + +from dlt.common.typing import TSecretValue +from dlt.common.configuration import configspec, ConfigEntryMissingException, ConfigFileNotFoundException, resolve +from dlt.common.configuration.specs import BaseConfiguration +from dlt.common.configuration.providers import environ as environ_provider +from dlt.common.configuration.exceptions import LookupTrace + 
+from tests.utils import preserve_environ +from tests.common.configuration.utils import MockProvider, WrongConfiguration, SecretConfiguration, NamespacedConfiguration, environment, mock_provider + + +@configspec +class SingleValConfiguration(BaseConfiguration): + sv: str + + +@configspec +class EmbeddedConfiguration(BaseConfiguration): + sv_config: Optional[SingleValConfiguration] + + +def test_namespaced_configuration(environment: Any) -> None: + with pytest.raises(ConfigEntryMissingException) as exc_val: + resolve.make_configuration(NamespacedConfiguration()) + assert list(exc_val.value.traces.keys()) == ["password"] + assert exc_val.value.spec_name == "NamespacedConfiguration" + # check trace + traces = exc_val.value.traces["password"] + # only one provider and namespace was tried + assert len(traces) == 1 + assert traces[0] == LookupTrace("Environment Variables", ["DLT_TEST"], "DLT_TEST__PASSWORD", None) + + # init vars work without namespace + C = resolve.make_configuration(NamespacedConfiguration(), initial_value={"password": "PASS"}) + assert C.password == "PASS" + + # env var must be prefixed + environment["PASSWORD"] = "PASS" + with pytest.raises(ConfigEntryMissingException) as exc_val: + resolve.make_configuration(NamespacedConfiguration()) + environment["DLT_TEST__PASSWORD"] = "PASS" + C = resolve.make_configuration(NamespacedConfiguration()) + assert C.password == "PASS" + + +def test_explicit_namespaces(mock_provider: MockProvider) -> None: + mock_provider.value = "value" + # mock providers separates namespaces with | and key with - + _, k = mock_provider.get_value("key", Any) + assert k == "-key" + _, k = mock_provider.get_value("key", Any, "ns1") + assert k == "ns1-key" + _, k = mock_provider.get_value("key", Any, "ns1", "ns2") + assert k == "ns1|ns2-key" + + # via make configuration + mock_provider.reset_stats() + resolve.make_configuration(SingleValConfiguration()) + assert mock_provider.last_namespace == () + mock_provider.reset_stats() + resolve.make_configuration(SingleValConfiguration(), namespaces=("ns1",)) + # value is returned only on empty namespace + assert mock_provider.last_namespace == () + # always start with more precise namespace + assert mock_provider.last_namespaces == [("ns1",), ()] + mock_provider.reset_stats() + resolve.make_configuration(SingleValConfiguration(), namespaces=("ns1", "ns2")) + assert mock_provider.last_namespaces == [("ns1", "ns2"), ("ns1",), ()] + + +def test_explicit_namespaces_with_namespaced_config(mock_provider: MockProvider) -> None: + mock_provider.value = "value" + # with namespaced config + mock_provider.return_value_on = ("DLT_TEST",) + resolve.make_configuration(NamespacedConfiguration()) + assert mock_provider.last_namespace == ("DLT_TEST",) + # namespace from config is mandatory, provider will not be queried with () + assert mock_provider.last_namespaces == [("DLT_TEST",)] + # namespaced config is always innermost + mock_provider.reset_stats() + resolve.make_configuration(NamespacedConfiguration(), namespaces=("ns1",)) + assert mock_provider.last_namespaces == [("ns1", "DLT_TEST"), ("DLT_TEST",)] + mock_provider.reset_stats() + resolve.make_configuration(NamespacedConfiguration(), namespaces=("ns1", "ns2")) + assert mock_provider.last_namespaces == [("ns1", "ns2", "DLT_TEST"), ("ns1", "DLT_TEST"), ("DLT_TEST",)] + + +def test_explicit_namespaces_from_embedded_config(mock_provider: MockProvider) -> None: + mock_provider.value = {"sv": "A"} + C = resolve.make_configuration(EmbeddedConfiguration()) + # we mock the dictionary 
below as the value for all requests + assert C.sv_config.sv == '{"sv": "A"}' + # following namespaces were used when resolving EmbeddedConfig: () - to resolve sv_config and then: ("sv_config",), () to resolve sv in sv_config + assert mock_provider.last_namespaces == [(), ("sv_config",), ()] + # embedded namespace inner of explicit + mock_provider.reset_stats() + C = resolve.make_configuration(EmbeddedConfiguration(), namespaces=("ns1",)) + assert mock_provider.last_namespaces == [("ns1",), (), ("ns1", "sv_config",), ("ns1",), ()] + + +def test_injected_namespaces(mock_provider: MockProvider) -> None: + container = Container() + mock_provider.value = "value" + + with container.injectable_configuration(resolve.ConfigNamespacesConfiguration(namespaces=("inj-ns1",))): + resolve.make_configuration(SingleValConfiguration()) + assert mock_provider.last_namespaces == [("inj-ns1",), ()] + mock_provider.reset_stats() + # explicit namespace inner of injected + resolve.make_configuration(SingleValConfiguration(), namespaces=("ns1",)) + assert mock_provider.last_namespaces == [("inj-ns1", "ns1"), ("inj-ns1",), ()] + # namespaced config inner of injected + mock_provider.reset_stats() + mock_provider.return_value_on = ("DLT_TEST",) + resolve.make_configuration(NamespacedConfiguration(), namespaces=("ns1",)) + assert mock_provider.last_namespaces == [("inj-ns1", "ns1", "DLT_TEST"), ("inj-ns1", "DLT_TEST"), ("DLT_TEST",)] + # explicit namespace inner of ns coming from embedded config + mock_provider.reset_stats() + mock_provider.return_value_on = () + mock_provider.value = {"sv": "A"} + resolve.make_configuration(EmbeddedConfiguration(), namespaces=("ns1",)) + # first we look for sv_config -> ("inj-ns1", "ns1"), ("inj-ns1",), () then we look for sv + assert mock_provider.last_namespaces == [("inj-ns1", "ns1"), ("inj-ns1",), (), ("inj-ns1", "ns1", "sv_config"), ("inj-ns1", "ns1"), ("inj-ns1",), ()] + + # multiple injected namespaces + with container.injectable_configuration(resolve.ConfigNamespacesConfiguration(namespaces=("inj-ns1", "inj-ns2"))): + mock_provider.reset_stats() + resolve.make_configuration(SingleValConfiguration()) + assert mock_provider.last_namespaces == [("inj-ns1", "inj-ns2"), ("inj-ns1",), ()] + mock_provider.reset_stats() + # explicit namespace inner of injected + resolve.make_configuration(SingleValConfiguration(), namespaces=("ns1",)) + assert mock_provider.last_namespaces == [("inj-ns1", "inj-ns2", "ns1"), ("inj-ns1", "inj-ns2"), ("inj-ns1",), ()] + + +def test_namespace_from_pipeline_name(mock_provider: MockProvider) -> None: + # AXIES__DESTINATION__STORAGE_CREDENTIALS__PRIVATE_KEY, DESTINATION__STORAGE_CREDENTIALS__PRIVATE_KEY, DESTINATION__PRIVATE_KEY, GCP__PRIVATE_KEY + # if pipeline name is present, keys will be looked up twice: with pipeline as top level namespace and without it + + container = Container() + mock_provider.value = "value" + + with container.injectable_configuration(resolve.ConfigNamespacesConfiguration(pipeline_name="PIPE")): + mock_provider.return_value_on = () + resolve.make_configuration(SingleValConfiguration()) + assert mock_provider.last_namespaces == [("PIPE",), ()] + + mock_provider.reset_stats() + resolve.make_configuration(SingleValConfiguration(), namespaces=("ns1",)) + # PIPE namespace is exhausted then another lookup without PIPE + assert mock_provider.last_namespaces == [("PIPE", "ns1"), ("PIPE",), ("ns1",), ()] + + mock_provider.return_value_on = ("PIPE", ) + mock_provider.reset_stats() + resolve.make_configuration(SingleValConfiguration(), 
namespaces=("ns1",)) + assert mock_provider.last_namespaces == [("PIPE", "ns1"), ("PIPE",)] + + # with both pipe and config namespaces are always present in lookup + # "PIPE", "DLT_TEST" + mock_provider.return_value_on = () + mock_provider.reset_stats() + # () will never be searched + with pytest.raises(ConfigEntryMissingException): + resolve.make_configuration(NamespacedConfiguration()) + mock_provider.return_value_on = ("DLT_TEST",) + mock_provider.reset_stats() + resolve.make_configuration(NamespacedConfiguration()) + assert mock_provider.last_namespaces == [("PIPE", "DLT_TEST"), ("DLT_TEST",)] + + # with pipeline and injected namespaces + with container.injectable_configuration(resolve.ConfigNamespacesConfiguration(pipeline_name="PIPE", namespaces=("inj-ns1",))): + mock_provider.return_value_on = () + mock_provider.reset_stats() + resolve.make_configuration(SingleValConfiguration()) + assert mock_provider.last_namespaces == [("PIPE", "inj-ns1"), ("PIPE",), ("inj-ns1",), ()] diff --git a/tests/common/test_typing.py b/tests/common/test_typing.py index 8fdb5afbb0..30da464a88 100644 --- a/tests/common/test_typing.py +++ b/tests/common/test_typing.py @@ -59,7 +59,7 @@ def test_extract_inner_type() -> None: assert extract_inner_type(str) is str assert extract_inner_type(NewType("NT1", str)) is str assert extract_inner_type(NewType("NT2", NewType("NT3", int))) is int - assert extract_inner_type(Optional[NewType("NT3", bool)]) is bool + assert extract_inner_type(Optional[NewType("NT3", bool)]) is bool # noqa l_1 = Literal[1, 2, 3] assert extract_inner_type(l_1) is int nt_l_2 = NewType("NTL2", float) diff --git a/tests/conftest.py b/tests/conftest.py index 3e5c53ce52..78498e0b64 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,17 +1,25 @@ import os +import dataclasses def pytest_configure(config): # patch the configurations to use test storage by default, we modify the types (classes) fields # the dataclass implementation will use those patched values when creating instances (the values present # in the declaration are not frozen allowing patching) - from dlt.common.configuration.specs import RunConfiguration, LoadVolumeConfiguration, NormalizeVolumeConfiguration, SchemaVolumeConfiguration + from dlt.common.configuration.specs import normalize_volume_configuration, run_configuration, load_volume_configuration, schema_volume_configuration test_storage_root = "_storage" - RunConfiguration.config_files_storage_path = os.path.join(test_storage_root, "config/%s") - LoadVolumeConfiguration.load_volume_path = os.path.join(test_storage_root, "load") - NormalizeVolumeConfiguration.normalize_volume_path = os.path.join(test_storage_root, "normalize") - SchemaVolumeConfiguration.schema_volume_path = os.path.join(test_storage_root, "schemas") + run_configuration.RunConfiguration.config_files_storage_path = os.path.join(test_storage_root, "config/%s") - assert RunConfiguration.config_files_storage_path == os.path.join(test_storage_root, "config/%s") - assert RunConfiguration().config_files_storage_path == os.path.join(test_storage_root, "config/%s") + load_volume_configuration.LoadVolumeConfiguration.load_volume_path = os.path.join(test_storage_root, "load") + + normalize_volume_configuration.NormalizeVolumeConfiguration.normalize_volume_path = os.path.join(test_storage_root, "normalize") + if hasattr(normalize_volume_configuration.NormalizeVolumeConfiguration, "__init__"): + # delete __init__, otherwise it will not be recreated by dataclass + 
delattr(normalize_volume_configuration.NormalizeVolumeConfiguration, "__init__") + normalize_volume_configuration.NormalizeVolumeConfiguration = dataclasses.dataclass(normalize_volume_configuration.NormalizeVolumeConfiguration, init=True, repr=False) + + schema_volume_configuration.SchemaVolumeConfiguration.schema_volume_path = os.path.join(test_storage_root, "schemas") + + assert run_configuration.RunConfiguration.config_files_storage_path == os.path.join(test_storage_root, "config/%s") + assert run_configuration.RunConfiguration().config_files_storage_path == os.path.join(test_storage_root, "config/%s") From 0d126d5d77c20afc51c7d36cd592664b58c25402 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Thu, 13 Oct 2022 22:38:51 +0200 Subject: [PATCH 31/66] removes legacy extractor --- dlt/extract/extractor_storage.py | 39 ------------------------------ dlt/extract/generator/__init__.py | 0 dlt/extract/generator/extractor.py | 0 3 files changed, 39 deletions(-) delete mode 100644 dlt/extract/extractor_storage.py delete mode 100644 dlt/extract/generator/__init__.py delete mode 100644 dlt/extract/generator/extractor.py diff --git a/dlt/extract/extractor_storage.py b/dlt/extract/extractor_storage.py deleted file mode 100644 index ce7d769c43..0000000000 --- a/dlt/extract/extractor_storage.py +++ /dev/null @@ -1,39 +0,0 @@ -import semver - -from dlt.common.json import json_typed_dumps -from dlt.common.typing import Any -from dlt.common.utils import uniq_id -from dlt.common.storages.file_storage import FileStorage -from dlt.common.storages import VersionedStorage, NormalizeStorage - - -class ExtractorStorageBase(VersionedStorage): - def __init__(self, version: semver.VersionInfo, is_owner: bool, storage: FileStorage, normalize_storage: NormalizeStorage) -> None: - self.normalize_storage = normalize_storage - super().__init__(version, is_owner, storage) - - def create_temp_folder(self) -> str: - tf_name = uniq_id() - self.storage.create_folder(tf_name) - return tf_name - - def save_json(self, name: str, d: Any) -> None: - # saves json using typed encoder - self.storage.save(name, json_typed_dumps(d)) - - def commit_events(self, schema_name: str, processed_file_path: str, dest_file_stem: str, no_processed_events: int, load_id: str, with_delete: bool = True) -> str: - raise NotImplementedError() - # schema name cannot contain underscores - # FileStorage.validate_file_name_component(schema_name) - - # dest_name = NormalizeStorage.build_extracted_file_stem(schema_name, dest_file_stem, no_processed_events, load_id) - # # if no events extracted from tracker, file is not saved - # if no_processed_events > 0: - # # moves file to possibly external storage and place in the dest folder atomically - # self.storage.copy_cross_storage_atomically( - # self.normalize_storage.storage.storage_path, NormalizeStorage.EXTRACTED_FOLDER, processed_file_path, dest_name) - - # if with_delete: - # self.storage.delete(processed_file_path) - - # return dest_name diff --git a/dlt/extract/generator/__init__.py b/dlt/extract/generator/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dlt/extract/generator/extractor.py b/dlt/extract/generator/extractor.py deleted file mode 100644 index e69de29bb2..0000000000 From 47e50294b657af47e56151e3725b7936176e94cb Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Thu, 13 Oct 2022 22:42:56 +0200 Subject: [PATCH 32/66] moves pipe, extract and sources code into extract --- dlt/extract/extract.py | 102 ++++++++++ dlt/extract/pipe.py | 449 
+++++++++++++++++++++++++++++++++++++++++ dlt/extract/sources.py | 219 ++++++++++++++++++++ 3 files changed, 770 insertions(+) create mode 100644 dlt/extract/extract.py create mode 100644 dlt/extract/pipe.py create mode 100644 dlt/extract/sources.py diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py new file mode 100644 index 0000000000..8acf0e014f --- /dev/null +++ b/dlt/extract/extract.py @@ -0,0 +1,102 @@ +import os +from typing import List + +from dlt.common.utils import uniq_id +from dlt.common.sources import TDirectDataItem, TDataItem +from dlt.common.schema import utils, TSchemaUpdate +from dlt.common.storages import NormalizeStorage, DataItemStorage +from dlt.common.configuration.specs import NormalizeVolumeConfiguration + + +from experiments.pipeline.pipe import PipeIterator +from experiments.pipeline.sources import DltResource, DltSource + + +class ExtractorStorage(DataItemStorage, NormalizeStorage): + EXTRACT_FOLDER = "extract" + + def __init__(self, C: NormalizeVolumeConfiguration) -> None: + # data item storage with jsonl with pua encoding + super().__init__("puae-jsonl", False, C) + self.initialize_storage() + + def initialize_storage(self) -> None: + self.storage.create_folder(ExtractorStorage.EXTRACT_FOLDER, exists_ok=True) + + def create_extract_id(self) -> str: + extract_id = uniq_id() + self.storage.create_folder(self._get_extract_path(extract_id)) + return extract_id + + def commit_extract_files(self, extract_id: str, with_delete: bool = True) -> None: + extract_path = self._get_extract_path(extract_id) + for file in self.storage.list_folder_files(extract_path, to_root=False): + from_file = os.path.join(extract_path, file) + to_file = os.path.join(NormalizeStorage.EXTRACTED_FOLDER, file) + if with_delete: + self.storage.atomic_rename(from_file, to_file) + else: + # create hardlink which will act as a copy + self.storage.link_hard(from_file, to_file) + if with_delete: + self.storage.delete_folder(extract_path, recursively=True) + + def _get_data_item_path_template(self, load_id: str, schema_name: str, table_name: str) -> str: + template = NormalizeStorage.build_extracted_file_stem(schema_name, table_name, "%s") + return self.storage.make_full_path(os.path.join(self._get_extract_path(load_id), template)) + + def _get_extract_path(self, extract_id: str) -> str: + return os.path.join(ExtractorStorage.EXTRACT_FOLDER, extract_id) + + +def extract(source: DltSource, storage: ExtractorStorage) -> TSchemaUpdate: + dynamic_tables: TSchemaUpdate = {} + schema = source.schema + extract_id = storage.create_extract_id() + + def _write_item(table_name: str, item: TDirectDataItem) -> None: + # normalize table name before writing so the name match the name in schema + # note: normalize function should be cached so there's almost no penalty on frequent calling + # note: column schema is not required for jsonl writer used here + # TODO: consider dropping DLT_METADATA_FIELD in all items before writing, this however takes CPU time + # event.pop(DLT_METADATA_FIELD, None) # type: ignore + storage.write_data_item(extract_id, schema.name, schema.normalize_table_name(table_name), item, None) + + def _write_dynamic_table(resource: DltResource, item: TDataItem) -> None: + table_name = resource._table_name_hint_fun(item) + existing_table = dynamic_tables.get(table_name) + if existing_table is None: + dynamic_tables[table_name] = [resource.table_schema(item)] + else: + # quick check if deep table merge is required + if resource._table_has_other_dynamic_hints: + new_table = 
resource.table_schema(item) + # this merges into existing table in place + utils.merge_tables(existing_table[0], new_table) + else: + # if there are no other dynamic hints besides name then we just leave the existing partial table + pass + # write to storage with inferred table name + _write_item(table_name, item) + + # yield from all selected pipes + for pipe_item in PipeIterator.from_pipes(source.pipes): + # get partial table from table template + resource = source.resource_by_pipe(pipe_item.pipe) + if resource._table_name_hint_fun: + if isinstance(pipe_item.item, List): + for item in pipe_item.item: + _write_dynamic_table(resource, item) + else: + _write_dynamic_table(resource, pipe_item.item) + else: + # write item belonging to table with static name + _write_item(resource.name, pipe_item.item) + + # flush all buffered writers + storage.close_writers(extract_id) + storage.commit_extract_files(extract_id) + + # returns set of partial tables + return dynamic_tables + diff --git a/dlt/extract/pipe.py b/dlt/extract/pipe.py new file mode 100644 index 0000000000..18142b6dbe --- /dev/null +++ b/dlt/extract/pipe.py @@ -0,0 +1,449 @@ +import types +import asyncio +from asyncio import Future +from concurrent.futures import ThreadPoolExecutor +from copy import deepcopy +from threading import Thread +from typing import Optional, Sequence, Union, Callable, Iterable, Iterator, List, NamedTuple, Awaitable, Tuple, Type, TYPE_CHECKING + +from dlt.common.configuration.inject import with_config +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec +from dlt.common.typing import TDataItem +from dlt.common.sources import TDirectDataItem, TResolvableDataItem + +if TYPE_CHECKING: + TItemFuture = Future[TDirectDataItem] +else: + TItemFuture = Future + +from dlt.common.exceptions import DltException +from dlt.common.time import sleep + + +class PipeItem(NamedTuple): + item: TDirectDataItem + step: int + pipe: "Pipe" + + +class ResolvablePipeItem(NamedTuple): + # mypy unable to handle recursive types, ResolvablePipeItem should take itself in "item" + item: Union[TResolvableDataItem, Iterator[TResolvableDataItem]] + step: int + pipe: "Pipe" + + +class FuturePipeItem(NamedTuple): + item: TItemFuture + step: int + pipe: "Pipe" + + +class SourcePipeItem(NamedTuple): + item: Union[Iterator[TResolvableDataItem], Iterator[ResolvablePipeItem]] + step: int + pipe: "Pipe" + + +# pipeline step may be iterator of data items or mapping function that returns data item or another iterator +TPipeStep = Union[ + Iterable[TResolvableDataItem], + Iterator[TResolvableDataItem], + Callable[[TDirectDataItem], TResolvableDataItem], + Callable[[TDirectDataItem], Iterator[TResolvableDataItem]], + Callable[[TDirectDataItem], Iterator[ResolvablePipeItem]] +] + + +class ForkPipe: + def __init__(self, pipe: "Pipe", step: int = -1) -> None: + self._pipes: List[Tuple["Pipe", int]] = [] + self.add_pipe(pipe, step) + + def add_pipe(self, pipe: "Pipe", step: int = -1) -> None: + if pipe not in self._pipes: + self._pipes.append((pipe, step)) + + def has_pipe(self, pipe: "Pipe") -> bool: + return pipe in [p[0] for p in self._pipes] + + def __call__(self, item: TDirectDataItem) -> Iterator[ResolvablePipeItem]: + for i, (pipe, step) in enumerate(self._pipes): + _it = item if i == 0 else deepcopy(item) + # always start at the beginning + yield ResolvablePipeItem(_it, step, pipe) + + +class FilterItem: + def __init__(self, filter_f: Callable[[TDataItem], bool]) -> None: + self._filter_f = filter_f + + def 
__call__(self, item: TDirectDataItem) -> Optional[TDirectDataItem]: + # item may be a list TDataItem or a single TDataItem + if isinstance(item, list): + item = [i for i in item if self._filter_f(i)] + if not item: + # item was fully consumed by the filter + return None + return item + else: + return item if self._filter_f(item) else None + + +class Pipe: + def __init__(self, name: str, steps: List[TPipeStep] = None, parent: "Pipe" = None) -> None: + self.name = name + self._steps: List[TPipeStep] = steps or [] + self._backup_steps: List[TPipeStep] = None + self._pipe_id = f"{name}_{id(self)}" + self.parent = parent + + @classmethod + def from_iterable(cls, name: str, gen: Union[Iterable[TResolvableDataItem], Iterator[TResolvableDataItem]]) -> "Pipe": + if isinstance(gen, Iterable): + gen = iter(gen) + return cls(name, [gen]) + + @property + def head(self) -> TPipeStep: + return self._steps[0] + + @property + def tail(self) -> TPipeStep: + return self._steps[-1] + + @property + def steps(self) -> List[TPipeStep]: + return self._steps + + def __getitem__(self, i: int) -> TPipeStep: + return self._steps[i] + + def __len__(self) -> int: + return len(self._steps) + + def fork(self, child_pipe: "Pipe", child_step: int = -1) -> "Pipe": + if len(self._steps) == 0: + raise CreatePipeException("Cannot fork to empty pipe") + fork_step = self.tail + if not isinstance(fork_step, ForkPipe): + fork_step = ForkPipe(child_pipe, child_step) + self.add_step(fork_step) + else: + if not fork_step.has_pipe(child_pipe): + fork_step.add_pipe(child_pipe, child_step) + return self + + def clone(self) -> "Pipe": + p = Pipe(self.name, self._steps.copy(), self.parent) + # clone shares the id with the original + p._pipe_id = self._pipe_id + return p + + # def backup(self) -> None: + # if self.has_backup: + # raise PipeBackupException("Pipe backup already exists, restore pipe first") + # self._backup_steps = self._steps.copy() + + # @property + # def has_backup(self) -> bool: + # return self._backup_steps is not None + + + # def restore(self) -> None: + # if not self.has_backup: + # raise PipeBackupException("No pipe backup to restore") + # self._steps = self._backup_steps + # self._backup_steps = None + + def add_step(self, step: TPipeStep) -> "Pipe": + if len(self._steps) == 0 and self.parent is None: + # first element must be iterable or iterator + if not isinstance(step, (Iterable, Iterator)): + raise CreatePipeException("First step of independent pipe must be Iterable or Iterator") + else: + if isinstance(step, Iterable): + step = iter(step) + self._steps.append(step) + else: + if isinstance(step, (Iterable, Iterator)): + if self.parent is not None: + raise CreatePipeException("Iterable or Iterator cannot be a step in dependent pipe") + else: + raise CreatePipeException("Iterable or Iterator can only be a first step in independent pipe") + if not callable(step): + raise CreatePipeException("Pipe step must be a callable taking exactly one data item as input") + self._steps.append(step) + return self + + def full_pipe(self) -> "Pipe": + if self.parent: + pipe = self.parent.full_pipe().steps + else: + pipe = [] + + # return pipe with resolved dependencies + pipe.extend(self._steps) + return Pipe(self.name, pipe) + + def __repr__(self) -> str: + return f"Pipe {self.name} ({self._pipe_id}) at {id(self)}" + + +class PipeIterator(Iterator[PipeItem]): + + @configspec + class PipeIteratorConfiguration: + max_parallel_items: int = 100 + worker_threads: int = 5 + futures_poll_interval: float = 0.01 + + + def __init__(self, 
max_parallel_items: int, worker_threads, futures_poll_interval: float) -> None: + self.max_parallel_items = max_parallel_items + self.worker_threads = worker_threads + self.futures_poll_interval = futures_poll_interval + + self._async_pool: asyncio.AbstractEventLoop = None + self._async_pool_thread: Thread = None + self._thread_pool: ThreadPoolExecutor = None + self._sources: List[SourcePipeItem] = [] + self._futures: List[FuturePipeItem] = [] + + @classmethod + @with_config(spec=PipeIteratorConfiguration) + def from_pipe(cls, pipe: Pipe, *, max_parallelism: int = 100, worker_threads: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": + if pipe.parent: + pipe = pipe.full_pipe() + # head must be iterator + assert isinstance(pipe.head, Iterator) + # create extractor + extract = cls(max_parallelism, worker_threads, futures_poll_interval) + # add as first source + extract._sources.append(SourcePipeItem(pipe.head, 0, pipe)) + return extract + + @classmethod + @with_config(spec=PipeIteratorConfiguration) + def from_pipes(cls, pipes: Sequence[Pipe], yield_parents: bool = True, *, max_parallelism: int = 100, worker_threads: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": + extract = cls(max_parallelism, worker_threads, futures_poll_interval) + # clone all pipes before iterating (recursively) as we will fork them and this add steps + pipes = PipeIterator.clone_pipes(pipes) + + def _fork_pipeline(pipe: Pipe) -> None: + if pipe.parent: + # fork the parent pipe + pipe.parent.fork(pipe) + # make the parent yield by sending a clone of item to itself with position at the end + if yield_parents and pipe.parent in pipes: + # fork is last step of the pipe so it will yield + pipe.parent.fork(pipe.parent, len(pipe.parent) - 1) + _fork_pipeline(pipe.parent) + else: + # head of independent pipe must be iterator + assert isinstance(pipe.head, Iterator) + # add every head as source only once + if not any(i.pipe == pipe for i in extract._sources): + print("add to sources: " + pipe.name) + extract._sources.append(SourcePipeItem(pipe.head, 0, pipe)) + + + for pipe in reversed(pipes): + _fork_pipeline(pipe) + + return extract + + def __next__(self) -> PipeItem: + pipe_item: Union[ResolvablePipeItem, SourcePipeItem] = None + # __next__ should call itself to remove the `while` loop and continue clauses but that may lead to stack overflows: there's no tail recursion opt in python + # https://stackoverflow.com/questions/13591970/does-python-optimize-tail-recursion (see Y combinator on how it could be emulated) + while True: + # do we need new item? + if pipe_item is None: + # process element from the futures + if len(self._futures) > 0: + pipe_item = self._resolve_futures() + # if none then take element from the newest source + if pipe_item is None: + pipe_item = self._get_source_item() + + if pipe_item is None: + if len(self._futures) == 0 and len(self._sources) == 0: + # no more elements in futures or sources + raise StopIteration() + else: + # if len(_sources + # print("waiting") + sleep(self.futures_poll_interval) + continue + + # if item is iterator, then add it as a new source + if isinstance(pipe_item.item, Iterator): + # print(f"adding iterable {item}") + self._sources.append(SourcePipeItem(pipe_item.item, pipe_item.step, pipe_item.pipe)) + pipe_item = None + continue + + if isinstance(pipe_item.item, Awaitable) or callable(pipe_item.item): + # do we have a free slot or one of the slots is done? 
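# Scheduling of deferred work: awaitables (e.g. coroutines returned by a step)
# are submitted to the background event loop created lazily in
# _ensure_async_pool(); plain callables go to the ThreadPoolExecutor. The
# resulting future is parked in self._futures together with its step and pipe
# and picked up later by _resolve_futures(). If the parallelism cap is reached
# and nothing has completed yet, the iterator sleeps for futures_poll_interval
# and retries the same item.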
+ if len(self._futures) < self.max_parallel_items or self._next_future() >= 0: + if isinstance(pipe_item.item, Awaitable): + future = asyncio.run_coroutine_threadsafe(pipe_item.item, self._ensure_async_pool()) + else: + future = self._ensure_thread_pool().submit(pipe_item.item) + # print(future) + self._futures.append(FuturePipeItem(future, pipe_item.step, pipe_item.pipe)) # type: ignore + # pipe item consumed for now, request a new one + pipe_item = None + continue + else: + # print("maximum futures exceeded, waiting") + sleep(self.futures_poll_interval) + # try same item later + continue + + # if we are at the end of the pipe then yield element + # print(pipe_item) + if pipe_item.step == len(pipe_item.pipe) - 1: + # must be resolved + if isinstance(pipe_item.item, (Iterator, Awaitable)) or callable(pipe_item.pipe): + raise PipeItemProcessingError("Pipe item not processed", pipe_item) + # mypy not able to figure out that item was resolved + return pipe_item # type: ignore + + # advance to next step + step = pipe_item.pipe[pipe_item.step + 1] + assert callable(step) + item = step(pipe_item.item) + pipe_item = ResolvablePipeItem(item, pipe_item.step + 1, pipe_item.pipe) # type: ignore + + + def _ensure_async_pool(self) -> asyncio.AbstractEventLoop: + # lazily create async pool is separate thread + if self._async_pool: + return self._async_pool + + def start_background_loop(loop: asyncio.AbstractEventLoop) -> None: + asyncio.set_event_loop(loop) + loop.run_forever() + + self._async_pool = asyncio.new_event_loop() + self._async_pool_thread = Thread(target=start_background_loop, args=(self._async_pool,), daemon=True) + self._async_pool_thread.start() + + # start or return async pool + return self._async_pool + + def _ensure_thread_pool(self) -> ThreadPoolExecutor: + # lazily start or return thread pool + if self._thread_pool: + return self._thread_pool + + self._thread_pool = ThreadPoolExecutor(self.worker_threads) + return self._thread_pool + + def __enter__(self) -> "PipeIterator": + return self + + def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: types.TracebackType) -> None: + + def stop_background_loop(loop: asyncio.AbstractEventLoop) -> None: + loop.stop() + + for f, _, _ in self._futures: + if not f.done(): + f.cancel() + print("stopping loop") + if self._async_pool: + self._async_pool.call_soon_threadsafe(stop_background_loop, self._async_pool) + print("joining thread") + self._async_pool_thread.join() + self._async_pool = None + self._async_pool_thread = None + if self._thread_pool: + self._thread_pool.shutdown(wait=True) + self._thread_pool = None + + def _next_future(self) -> int: + return next((i for i, val in enumerate(self._futures) if val.item.done()), -1) + + def _resolve_futures(self) -> ResolvablePipeItem: + # no futures at all + if len(self._futures) == 0: + return None + + # anything done? 
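# Completed futures are consumed out of submission order: the first one found
# done is popped, a cancelled future is skipped by recursing, a worker exception
# is re-raised here in the consuming thread, and a successful result is wrapped
# back into a ResolvablePipeItem with its original step and pipe so it continues
# through the remaining steps of that pipe.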
+ idx = self._next_future() + if idx == -1: + # nothing done + return None + + future, step, pipe = self._futures.pop(idx) + + if future.cancelled(): + # get next future + return self._resolve_futures() + + if future.exception(): + raise future.exception() + + return ResolvablePipeItem(future.result(), step, pipe) + + def _get_source_item(self) -> ResolvablePipeItem: + # no more sources to iterate + if len(self._sources) == 0: + return None + + # get items from last added iterator, this makes the overall Pipe as close to FIFO as possible + gen, step, pipe = self._sources[-1] + try: + item = next(gen) + # full pipe item may be returned, this is used by ForkPipe step + # to redirect execution of an item to another pipe + if isinstance(item, ResolvablePipeItem): + return item + else: + # keep the item assigned step and pipe + return ResolvablePipeItem(item, step, pipe) + except StopIteration: + # remove empty iterator and try another source + self._sources.pop() + return self._get_source_item() + + @staticmethod + def clone_pipes(pipes: Sequence[Pipe]) -> Sequence[Pipe]: + # will clone the pipes including the dependent ones + cloned_pipes = [p.clone() for p in pipes] + cloned_pairs = {id(p): c for p, c in zip(pipes, cloned_pipes)} + + for clone in cloned_pipes: + while True: + if not clone.parent: + break + # if already a clone + if clone.parent in cloned_pairs.values(): + break + # clone if parent pipe not yet cloned + if id(clone.parent) not in cloned_pairs: + print("cloning:" + clone.parent.name) + cloned_pairs[id(clone.parent)] = clone.parent.clone() + # replace with clone + print(f"replace depends on {clone.name} to {clone.parent.name}") + clone.parent = cloned_pairs[id(clone.parent)] + # recurr with clone + clone = clone.parent + + return cloned_pipes + + +class PipeException(DltException): + pass + + +class CreatePipeException(PipeException): + pass + + +class PipeItemProcessingError(PipeException): + pass + diff --git a/dlt/extract/sources.py b/dlt/extract/sources.py new file mode 100644 index 0000000000..4b85646909 --- /dev/null +++ b/dlt/extract/sources.py @@ -0,0 +1,219 @@ +import contextlib +from copy import deepcopy +import inspect +from typing import AsyncIterable, AsyncIterator, Coroutine, Dict, Generator, Iterable, Iterator, List, Set, TypedDict, Union, Awaitable, Callable, Sequence, TypeVar, cast, Optional, Any +from dlt.common.exceptions import DltException +from dlt.common.schema.utils import new_table + +from dlt.common.typing import TDataItem +from dlt.common.sources import TFunDataItemDynHint, TDirectDataItem +from dlt.common.schema.schema import Schema +from dlt.common.schema.typing import TPartialTableSchema, TTableSchema, TTableSchemaColumns, TWriteDisposition + +from experiments.pipeline.pipe import FilterItem, Pipe, CreatePipeException, PipeIterator + + +class TTableSchemaTemplate(TypedDict, total=False): + name: Union[str, TFunDataItemDynHint] + description: Union[str, TFunDataItemDynHint] + write_disposition: Union[TWriteDisposition, TFunDataItemDynHint] + # table_sealed: Optional[bool] + parent: Union[str, TFunDataItemDynHint] + columns: Union[TTableSchemaColumns, TFunDataItemDynHint] + + +class DltResourceSchema: + def __init__(self, name: str, table_schema_template: TTableSchemaTemplate = None): + # self.__name__ = name + self.name = name + self._table_name_hint_fun: TFunDataItemDynHint = None + self._table_has_other_dynamic_hints: bool = False + self._table_schema_template: TTableSchemaTemplate = None + self._table_schema: TPartialTableSchema = None + if 
table_schema_template: + self._set_template(table_schema_template) + + def table_schema(self, item: TDataItem = None) -> TPartialTableSchema: + + if not self._table_schema_template: + # if table template is not present, generate partial table from name + if not self._table_schema: + self._table_schema = new_table(self.name) + return self._table_schema + + def _resolve_hint(hint: Union[Any, TFunDataItemDynHint]) -> Any: + if callable(hint): + return hint(item) + else: + return hint + + # if table template present and has dynamic hints, the data item must be provided + if self._table_name_hint_fun: + if item is None: + raise DataItemRequiredForDynamicTableHints(self.name) + else: + cloned_template = deepcopy(self._table_schema_template) + return cast(TPartialTableSchema, {k: _resolve_hint(v) for k, v in cloned_template.items()}) + else: + return cast(TPartialTableSchema, self._table_schema_template) + + def _set_template(self, table_schema_template: TTableSchemaTemplate) -> None: + # if "name" is callable in the template then the table schema requires actual data item to be inferred + name_hint = table_schema_template.get("name") + if callable(name_hint): + self._table_name_hint_fun = name_hint + # check if any other hints in the table template should be inferred from data + self._table_has_other_dynamic_hints = any(callable(v) for k, v in table_schema_template.items() if k != "name") + + if self._table_has_other_dynamic_hints and not self._table_name_hint_fun: + raise InvalidTableSchemaTemplate("Table name must be a function if any other table hint is a function") + self._table_schema_template = table_schema_template + + +class DltResource(Iterable[TDirectDataItem], DltResourceSchema): + def __init__(self, pipe: Pipe, table_schema_template: TTableSchemaTemplate): + self.name = pipe.name + self._pipe = pipe + super().__init__(self.name, table_schema_template) + + @classmethod + def from_data(cls, data: Any, name: str = None, table_schema_template: TTableSchemaTemplate = None) -> "DltResource": + # call functions assuming that they do not take any parameters, typically they are generator functions + if callable(data): + data = data() + + if isinstance(data, DltResource): + return data + + if isinstance(data, Pipe): + return cls(data, table_schema_template) + + # several iterable types are not allowed and must be excluded right away + if isinstance(data, (AsyncIterator, AsyncIterable, str, dict)): + raise InvalidResourceDataType("Invalid data type for DltResource", type(data)) + + # create resource from iterator or iterable + if isinstance(data, (Iterable, Iterator)): + if inspect.isgenerator(data): + name = name or data.__name__ + else: + name = name or None + if not name: + raise ResourceNameRequired("The DltResource name was not provide or could not be inferred.") + pipe = Pipe.from_iterable(name, data) + return cls(pipe, table_schema_template) + + # some other data type that is not supported + raise InvalidResourceDataType("Invalid data type for DltResource", type(data)) + + + def select(self, *table_names: Iterable[str]) -> "DltResource": + if not self._table_name_hint_fun: + raise CreatePipeException("Table name is not dynamic, table selection impossible") + + def _filter(item: TDataItem) -> bool: + return self._table_name_hint_fun(item) in table_names + + # add filtering function at the end of pipe + self._pipe.add_step(FilterItem(_filter)) + return self + + def map(self) -> None: + raise NotImplementedError() + + def flat_map(self) -> None: + raise NotImplementedError() + + def 
filter(self) -> None: + raise NotImplementedError() + + def __iter__(self) -> Iterator[TDirectDataItem]: + return map(lambda item: item.item, PipeIterator.from_pipe(self._pipe)) + + def __repr__(self) -> str: + return f"DltResource {self.name} ({self._pipe._pipe_id}) at {id(self)}" + + +class DltSource(Iterable[TDirectDataItem]): + def __init__(self, schema: Schema, resources: Sequence[DltResource] = None) -> None: + self.name = schema.name + self._schema = schema + self._resources: List[DltResource] = list(resources or []) + self._enabled_resource_names: Set[str] = set(r.name for r in self._resources) + + @classmethod + def from_data(cls, schema: Schema, data: Any) -> "DltSource": + # creates source from various forms of data + if isinstance(data, DltSource): + return data + + # several iterable types are not allowed and must be excluded right away + if isinstance(data, (AsyncIterator, AsyncIterable, str, dict)): + raise InvalidSourceDataType("Invalid data type for DltSource", type(data)) + + # in case of sequence, enumerate items and convert them into resources + if isinstance(data, Sequence): + resources = [DltResource.from_data(i) for i in data] + else: + resources = [DltResource.from_data(data)] + + return cls(schema, resources) + + + def __getitem__(self, name: str) -> List[DltResource]: + if name not in self._enabled_resource_names: + raise KeyError(name) + return [r for r in self._resources if r.name == name] + + def resource_by_pipe(self, pipe: Pipe) -> DltResource: + # identify pipes by memory pointer + return next(r for r in self._resources if r._pipe._pipe_id is pipe._pipe_id) + + @property + def resources(self) -> Sequence[DltResource]: + return [r for r in self._resources if r.name in self._enabled_resource_names] + + @property + def pipes(self) -> Sequence[Pipe]: + return [r._pipe for r in self._resources if r.name in self._enabled_resource_names] + + @property + def schema(self) -> Schema: + return self._schema + + def discover_schema(self) -> Schema: + # extract tables from all resources and update internal schema + for r in self._resources: + # names must be normalized here + with contextlib.suppress(DataItemRequiredForDynamicTableHints): + partial_table = self._schema.normalize_table_identifiers(r.table_schema()) + self._schema.update_schema(partial_table) + return self._schema + + def select(self, *resource_names: str) -> "DltSource": + # make sure all selected resources exist + for name in resource_names: + self.__getitem__(name) + self._enabled_resource_names = set(resource_names) + return self + + + def __iter__(self) -> Iterator[TDirectDataItem]: + return map(lambda item: item.item, PipeIterator.from_pipes(self.pipes)) + + def __repr__(self) -> str: + return f"DltSource {self.name} at {id(self)}" + + +class DltSourceException(DltException): + pass + + +class DataItemRequiredForDynamicTableHints(DltException): + def __init__(self, resource_name: str) -> None: + self.resource_name = resource_name + super().__init__(f"Instance of Data Item required to generate table schema in resource {resource_name}") + + + +# class From 5c76bcb0ecd120c6ce2a330b8a6f72c03758f164 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 14 Oct 2022 12:27:28 +0200 Subject: [PATCH 33/66] adds config injection to normalize, passes instantiated configs to workers --- dlt/normalize/__init__.py | 3 +- dlt/normalize/configuration.py | 12 +-- dlt/normalize/normalize.py | 82 ++++++++++---------- tests/normalize/mock_rasa_json_normalizer.py | 1 + tests/normalize/test_normalize.py | 32 ++++---- 5 
files changed, 61 insertions(+), 69 deletions(-) diff --git a/dlt/normalize/__init__.py b/dlt/normalize/__init__.py index a55a9257f8..a40a5eaa7e 100644 --- a/dlt/normalize/__init__.py +++ b/dlt/normalize/__init__.py @@ -1,2 +1 @@ -from dlt._version import normalize_version as __version__ -from .normalize import Normalize, configuration \ No newline at end of file +from .normalize import Normalize \ No newline at end of file diff --git a/dlt/normalize/configuration.py b/dlt/normalize/configuration.py index 045aaadc2a..1a924520d6 100644 --- a/dlt/normalize/configuration.py +++ b/dlt/normalize/configuration.py @@ -1,16 +1,10 @@ from dlt.common.typing import StrAny from dlt.common.data_writers import TLoaderFileFormat from dlt.common.configuration import make_configuration, configspec -from dlt.common.configuration.specs import PoolRunnerConfiguration, TPoolType, NormalizeVolumeConfiguration, LoadVolumeConfiguration, SchemaVolumeConfiguration +from dlt.common.configuration.specs import PoolRunnerConfiguration, TPoolType -from . import __version__ - -@configspec -class NormalizeConfiguration(PoolRunnerConfiguration, NormalizeVolumeConfiguration, LoadVolumeConfiguration, SchemaVolumeConfiguration): +@configspec(init=True) +class NormalizeConfiguration(PoolRunnerConfiguration): loader_file_format: TLoaderFileFormat = "jsonl" # jsonp or insert commands will be generated pool_type: TPoolType = "process" - - -def configuration(initial_values: StrAny = None) -> NormalizeConfiguration: - return make_configuration(NormalizeConfiguration(), initial_value=initial_values) diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 89b3afd0e6..0d85b97e85 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -4,24 +4,27 @@ from prometheus_client import Counter, CollectorRegistry, REGISTRY, Gauge from dlt.common import pendulum, signals, json, logger +from dlt.common.configuration import with_config +from dlt.common.configuration.specs.load_volume_configuration import LoadVolumeConfiguration +from dlt.common.configuration.specs.normalize_volume_configuration import NormalizeVolumeConfiguration +from dlt.common.data_writers.writers import TLoaderFileFormat from dlt.common.json import custom_pua_decode -from dlt.cli import TRunnerArgs -from dlt.common.runners import TRunMetrics, Runnable, run_pool, initialize_runner -from dlt.common.schema.typing import TTableSchemaColumns +from dlt.common.runners import TRunMetrics, Runnable +from dlt.common.schema.typing import TStoredSchema, TTableSchemaColumns from dlt.common.storages.exceptions import SchemaNotFoundError from dlt.common.storages import NormalizeStorage, SchemaStorage, LoadStorage from dlt.common.telemetry import get_logging_extras -from dlt.common.typing import StrAny, TDataItem +from dlt.common.typing import ConfigValue, StrAny, TDataItem from dlt.common.exceptions import PoolException from dlt.common.schema import TSchemaUpdate, Schema from dlt.common.schema.exceptions import CannotCoerceColumnException -from dlt.normalize.configuration import configuration, NormalizeConfiguration +from dlt.normalize.configuration import NormalizeConfiguration # normalize worker wrapping function (map_parallel, map_single) return type TMapFuncRV = Tuple[int, List[TSchemaUpdate], List[Sequence[str]]] # (total items processed, list of schema updates, list of processed files) # normalize worker wrapping function signature -TMapFuncType = Callable[[str, str, Sequence[str]], TMapFuncRV] # input parameters: (schema name, load_id, list of 
files to process) +TMapFuncType = Callable[[Schema, str, Sequence[str]], TMapFuncRV] # input parameters: (schema name, load_id, list of files to process) class Normalize(Runnable[ProcessPool]): @@ -32,8 +35,9 @@ class Normalize(Runnable[ProcessPool]): schema_version_gauge: Gauge = None load_package_counter: Counter = None - def __init__(self, C: NormalizeConfiguration, collector: CollectorRegistry = REGISTRY, schema_storage: SchemaStorage = None) -> None: - self.CONFIG = C + @with_config(spec=NormalizeConfiguration, namespaces=("normalize",)) + def __init__(self, config: NormalizeConfiguration = ConfigValue, collector: CollectorRegistry = REGISTRY, schema_storage: SchemaStorage = None) -> None: + self.config = config self.pool: ProcessPool = None self.normalize_storage: NormalizeStorage = None self.load_storage: LoadStorage = None @@ -42,7 +46,7 @@ def __init__(self, C: NormalizeConfiguration, collector: CollectorRegistry = REG # setup storages self.create_storages() # create schema storage with give type - self.schema_storage = schema_storage or SchemaStorage(self.CONFIG, makedirs=True) + self.schema_storage = schema_storage or SchemaStorage(makedirs=True) try: self.create_gauges(collector) except ValueError as v: @@ -58,9 +62,9 @@ def create_gauges(registry: CollectorRegistry) -> None: Normalize.load_package_counter = Gauge("normalize_load_packages_created_count", "Count of load package created", ["schema"], registry=registry) def create_storages(self) -> None: - self.normalize_storage = NormalizeStorage(True, self.CONFIG) + self.normalize_storage = NormalizeStorage(True) # normalize saves in preferred format but can read all supported formats - self.load_storage = LoadStorage(True, self.CONFIG, self.CONFIG.loader_file_format, LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + self.load_storage = LoadStorage(True, self.config.loader_file_format, LoadStorage.ALL_SUPPORTED_FILE_FORMATS) @staticmethod @@ -74,10 +78,17 @@ def load_or_create_schema(schema_storage: SchemaStorage, schema_name: str) -> Sc return schema @staticmethod - def w_normalize_files(CONFIG: NormalizeConfiguration, schema_name: str, load_id: str, extracted_items_files: Sequence[str]) -> Tuple[TSchemaUpdate, int]: - schema = Normalize.load_or_create_schema(SchemaStorage(CONFIG, makedirs=False), schema_name) - load_storage = LoadStorage(False, CONFIG, CONFIG.loader_file_format, LoadStorage.ALL_SUPPORTED_FILE_FORMATS) - normalize_storage = NormalizeStorage(False, CONFIG) + def w_normalize_files( + normalize_storage_config: NormalizeVolumeConfiguration, + loader_storage_config: LoadVolumeConfiguration, + loader_file_format: TLoaderFileFormat, + stored_schema: TStoredSchema, + load_id: str, + extracted_items_files: Sequence[str] + ) -> Tuple[TSchemaUpdate, int]: + schema = Schema.from_stored_schema(stored_schema) + load_storage = LoadStorage(False, loader_file_format, LoadStorage.ALL_SUPPORTED_FILE_FORMATS, loader_storage_config) + normalize_storage = NormalizeStorage(False, normalize_storage_config) schema_update: TSchemaUpdate = {} total_items = 0 @@ -87,7 +98,7 @@ def w_normalize_files(CONFIG: NormalizeConfiguration, schema_name: str, load_id: for extracted_items_file in extracted_items_files: line_no: int = 0 root_table_name = NormalizeStorage.parse_normalize_file_name(extracted_items_file).table_name - logger.debug(f"Processing extracted items in {extracted_items_file} in load_id {load_id} with table name {root_table_name} and schema {schema_name}") + logger.debug(f"Processing extracted items in {extracted_items_file} in load_id 
{load_id} with table name {root_table_name} and schema {schema.name}") with normalize_storage.storage.open_file(extracted_items_file) as f: # enumerate jsonl file line by line for line_no, line in enumerate(f): @@ -144,15 +155,24 @@ def _w_normalize_chunk(load_storage: LoadStorage, schema: Schema, load_id: str, items_count += 1 return schema_update, items_count - def map_parallel(self, schema_name: str, load_id: str, files: Sequence[str]) -> TMapFuncRV: + def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TMapFuncRV: # TODO: maybe we should chunk by file size, now map all files to workers chunk_files = [files] - param_chunk = [(self.CONFIG, schema_name, load_id, files) for files in chunk_files] + schema_dict = schema.to_dict() + config_tuple = (self.normalize_storage.config, self.load_storage.config, self.config.loader_file_format, schema_dict) + param_chunk = [(*config_tuple, load_id, files) for files in chunk_files] processed_chunks = self.pool.starmap(Normalize.w_normalize_files, param_chunk) return sum([t[1] for t in processed_chunks]), [t[0] for t in processed_chunks], chunk_files - def map_single(self, schema_name: str, load_id: str, files: Sequence[str]) -> TMapFuncRV: - processed_chunk = Normalize.w_normalize_files(self.CONFIG, schema_name, load_id, files) + def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TMapFuncRV: + processed_chunk = Normalize.w_normalize_files( + self.normalize_storage.config, + self.load_storage.config, + self.config.loader_file_format, + schema.to_dict(), + load_id, + files + ) return processed_chunk[1], [processed_chunk[0]], [files] def update_schema(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> int: @@ -166,10 +186,11 @@ def update_schema(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> return updates_count def spool_files(self, schema_name: str, load_id: str, map_f: TMapFuncType, files: Sequence[str]) -> None: + schema = Normalize.load_or_create_schema(self.schema_storage, schema_name) + # process files in parallel or in single thread, depending on map_f - total_items, schema_updates, chunk_files = map_f(schema_name, load_id, files) + total_items, schema_updates, chunk_files = map_f(schema, load_id, files) - schema = Normalize.load_or_create_schema(self.schema_storage, schema_name) # gather schema from all manifests, validate consistency and combine updates_count = self.update_schema(schema, schema_updates) self.schema_version_gauge.labels(schema_name).set(schema.version) @@ -233,20 +254,3 @@ def run(self, pool: ProcessPool) -> TRunMetrics: self.spool_schema_files(schema_name, list(files_in_schema)) # return info on still pending files (if extractor saved something in the meantime) return TRunMetrics(False, False, len(self.normalize_storage.list_files_to_normalize_sorted())) - - -def main(args: TRunnerArgs) -> int: - # initialize runner - C = configuration(args._asdict()) - initialize_runner(C) - # create objects and gauges - try: - n = Normalize(C, REGISTRY) - except Exception: - logger.exception("init module") - return -1 - return run_pool(C, n) - - -def run_main(args: TRunnerArgs) -> None: - exit(main(args)) diff --git a/tests/normalize/mock_rasa_json_normalizer.py b/tests/normalize/mock_rasa_json_normalizer.py index 3975c484b9..e516e7527a 100644 --- a/tests/normalize/mock_rasa_json_normalizer.py +++ b/tests/normalize/mock_rasa_json_normalizer.py @@ -5,6 +5,7 @@ def normalize_data_item(schema: Schema, source_event: TDataItem, load_id: str, table_name: str) -> 
TNormalizedRowIterator: + print(f"CUSTOM NORM: {schema.name} {table_name}") if schema.name == "event": # this emulates rasa parser on standard parser event = {"sender_id": source_event["sender_id"], "timestamp": source_event["timestamp"], "type": source_event["event"]} diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py index 4830dd32ea..883a8b44dc 100644 --- a/tests/normalize/test_normalize.py +++ b/tests/normalize/test_normalize.py @@ -1,6 +1,6 @@ -from typing import Dict, List, Sequence import pytest from fnmatch import fnmatch +from typing import Dict, List, Sequence from prometheus_client import CollectorRegistry from multiprocessing import get_start_method, Pool from multiprocessing.dummy import Pool as ThreadPool @@ -11,11 +11,11 @@ from dlt.common.schema import TDataType from dlt.common.storages import NormalizeStorage, LoadStorage -from experiments.pipeline.extract import ExtractorStorage -from dlt.normalize import Normalize, configuration as normalize_configuration, __version__ +from dlt.extract.extract import ExtractorStorage +from dlt.normalize import Normalize from tests.cases import JSON_TYPED_DICT, JSON_TYPED_DICT_TYPES -from tests.utils import TEST_STORAGE_ROOT, assert_no_dict_key_starts_with, write_version, clean_test_storage, init_logger +from tests.utils import TEST_STORAGE_ROOT, TEST_DICT_CONFIG_PROVIDER, assert_no_dict_key_starts_with, write_version, clean_test_storage, init_logger from tests.normalize.utils import json_case_path @@ -35,12 +35,10 @@ def rasa_normalize() -> Normalize: def init_normalize(default_schemas_path: str = None) -> Normalize: clean_test_storage() - initial = {} - if default_schemas_path: - initial = {"import_schema_path": default_schemas_path, "external_schema_format": "json"} - n = Normalize(normalize_configuration(initial), CollectorRegistry()) + with TEST_DICT_CONFIG_PROVIDER.values({"import_schema_path": default_schemas_path, "external_schema_format": "json"}): + n = Normalize(collector=CollectorRegistry()) # set jsonl as default writer - n.load_storage.loader_file_format = n.CONFIG.loader_file_format = "jsonl" + n.load_storage.loader_file_format = n.config.loader_file_format = "jsonl" return n @@ -75,7 +73,7 @@ def test_normalize_single_user_event_jsonl(raw_normalize: Normalize) -> None: def test_normalize_single_user_event_insert(raw_normalize: Normalize) -> None: - raw_normalize.load_storage.loader_file_format = raw_normalize.CONFIG.loader_file_format = "insert_values" + raw_normalize.load_storage.loader_file_format = raw_normalize.config.loader_file_format = "insert_values" expected_tables, load_files = normalize_event_user(raw_normalize, "event.event.user_load_1", EXPECTED_USER_TABLES) # verify values line for expected_table in expected_tables: @@ -131,7 +129,7 @@ def test_preserve_slot_complex_value_json_l(rasa_normalize: Normalize) -> None: def test_preserve_slot_complex_value_insert(rasa_normalize: Normalize) -> None: - rasa_normalize.load_storage.loader_file_format = rasa_normalize.CONFIG.loader_file_format = "insert_values" + rasa_normalize.load_storage.loader_file_format = rasa_normalize.config.loader_file_format = "insert_values" load_id = normalize_cases(rasa_normalize, ["event.event.slot_session_metadata_1"]) load_files = expect_load_package(rasa_normalize.load_storage, load_id, ["event", "event_slot"]) event_text, lines = expect_lines_file(rasa_normalize.load_storage, load_files["event_slot"], 2) @@ -154,7 +152,7 @@ def test_normalize_raw_type_hints(rasa_normalize: Normalize) -> None: def 
test_normalize_many_events_insert(rasa_normalize: Normalize) -> None: - rasa_normalize.load_storage.loader_file_format = rasa_normalize.CONFIG.loader_file_format = "insert_values" + rasa_normalize.load_storage.loader_file_format = rasa_normalize.config.loader_file_format = "insert_values" load_id = normalize_cases(rasa_normalize, ["event.event.many_load_2", "event.event.user_load_1"]) expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + ["event_bot", "event_action"] load_files = expect_load_package(rasa_normalize.load_storage, load_id, expected_tables) @@ -177,7 +175,7 @@ def test_normalize_many_events(rasa_normalize: Normalize) -> None: def test_normalize_many_schemas(rasa_normalize: Normalize) -> None: - rasa_normalize.load_storage.loader_file_format = rasa_normalize.CONFIG.loader_file_format = "insert_values" + rasa_normalize.load_storage.loader_file_format = rasa_normalize.config.loader_file_format = "insert_values" extract_cases( rasa_normalize.normalize_storage, ["event.event.many_load_2", "event.event.user_load_1", "ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2"] @@ -206,7 +204,7 @@ def test_normalize_many_schemas(rasa_normalize: Normalize) -> None: def test_normalize_typed_json(raw_normalize: Normalize) -> None: - raw_normalize.load_storage.loader_file_format = raw_normalize.CONFIG.loader_file_format = "jsonl" + raw_normalize.load_storage.loader_file_format = raw_normalize.config.loader_file_format = "jsonl" extract_items(raw_normalize.normalize_storage, [JSON_TYPED_DICT], "special", "special") raw_normalize.run(ThreadPool(processes=1)) loads = raw_normalize.load_storage.list_packages() @@ -233,7 +231,7 @@ def test_normalize_typed_json(raw_normalize: Normalize) -> None: def extract_items(normalize_storage: NormalizeStorage, items: Sequence[StrAny], schema_name: str, table_name: str) -> None: - extractor = ExtractorStorage(normalize_storage.CONFIG) + extractor = ExtractorStorage(normalize_storage.config) extract_id = extractor.create_extract_id() extractor.write_data_item(extract_id, schema_name, table_name, items, None) extractor.close_writers(extract_id) @@ -293,7 +291,3 @@ def assert_timestamp_data_type(load_storage: LoadStorage, data_type: TDataType) event_schema = load_storage.load_package_schema(loads[0]) # in raw normalize timestamp column must not be coerced to timestamp assert event_schema.get_table_columns("event")["timestamp"]["data_type"] == data_type - - -def test_version() -> None: - assert normalize_configuration()._version == __version__ From afcabaeae7e6e5a4a00ecfd8e83e24640427f9c3 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 14 Oct 2022 12:28:57 +0200 Subject: [PATCH 34/66] removes per component version, uses package version in logging --- Makefile | 8 ++------ dlt/__init__.py | 2 +- dlt/_version.py | 3 --- dlt/common/__init__.py | 1 - dlt/common/logger.py | 15 +++++++++------ dlt/dbt_runner/__init__.py | 1 - dlt/dbt_runner/_version.py | 1 - dlt/load/__init__.py | 1 - tests/common/test_logging.py | 23 +++++++++-------------- tests/tools/create_storages.py | 6 +++--- 10 files changed, 24 insertions(+), 37 deletions(-) delete mode 100644 dlt/_version.py delete mode 100644 dlt/dbt_runner/_version.py diff --git a/Makefile b/Makefile index d5108293b8..0d19ae3f6f 100644 --- a/Makefile +++ b/Makefile @@ -13,15 +13,11 @@ VERSION := ${AUTV}${VERSION_SUFFIX} VERSION_MM := ${AUTVMINMAJ}${VERSION_SUFFIX} -# dbt runner version info -DBT_AUTV=$(shell python3 -c "from dlt.dbt_runner._version import __version__;print(__version__)") -DBT_AUTVMINMAJ=$(shell 
python3 -c "from dlt.dbt_runner._version import __version__;print('.'.join(__version__.split('.')[:-1]))") - DBT_NAME := scalevector/dlt-dbt-runner DBT_IMG := ${DBT_NAME}:${TAG} DBT_LATEST := ${DBT_NAME}:latest${VERSION_SUFFIX} -DBT_VERSION := ${DBT_AUTV}${VERSION_SUFFIX} -DBT_VERSION_MM := ${DBT_AUTVMINMAJ}${VERSION_SUFFIX} +DBT_VERSION := ${AUTV}${VERSION_SUFFIX} +DBT_VERSION_MM := ${AUTVMINMAJ}${VERSION_SUFFIX} install-poetry: ifneq ($(VIRTUAL_ENV),) diff --git a/dlt/__init__.py b/dlt/__init__.py index b53f5a0b6b..a68927d6ca 100644 --- a/dlt/__init__.py +++ b/dlt/__init__.py @@ -1 +1 @@ -from dlt._version import common_version as __version__ \ No newline at end of file +__version__ = "0.1.0" \ No newline at end of file diff --git a/dlt/_version.py b/dlt/_version.py deleted file mode 100644 index ddd0d93607..0000000000 --- a/dlt/_version.py +++ /dev/null @@ -1,3 +0,0 @@ -common_version = "0.1.0" -loader_version = "0.1.0" -normalize_version = "0.1.0" diff --git a/dlt/common/__init__.py b/dlt/common/__init__.py index 6da3ee3a0e..7a72b56a9b 100644 --- a/dlt/common/__init__.py +++ b/dlt/common/__init__.py @@ -3,4 +3,3 @@ from .pendulum import pendulum # noqa: F401 from .json import json # noqa: F401, I251 from .time import sleep # noqa: F401 -from dlt._version import common_version as __version__ diff --git a/dlt/common/logger.py b/dlt/common/logger.py index eecf86ede6..9a2094a871 100644 --- a/dlt/common/logger.py +++ b/dlt/common/logger.py @@ -2,6 +2,7 @@ import json_logging import traceback import sentry_sdk +from importlib.metadata import version as pkg_version, PackageNotFoundError from sentry_sdk.transport import HttpTransport from sentry_sdk.integrations.logging import LoggingIntegration from logging import LogRecord, Logger @@ -12,7 +13,7 @@ from dlt.common.configuration.specs import RunConfiguration from dlt.common.utils import filter_env_vars -from dlt._version import common_version as __version__ +from dlt import __version__ DLT_LOGGER_NAME = "dlt" LOGGER: Logger = None @@ -129,10 +130,12 @@ def wrapper(msg: str, *args: Any, **kwargs: Any) -> None: def _extract_version_info(config: RunConfiguration) -> StrStr: - version_info = {"version": __version__, "component_name": config.pipeline_name} - version = getattr(config, "_version", None) - if version: - version_info["component_version"] = version + try: + version = pkg_version("python-dlt") + except PackageNotFoundError: + # if there's no package context, take the version from the code + version = __version__ + version_info = {"dlt_version": version, "pipeline_name": config.pipeline_name} # extract envs with build info version_info.update(filter_env_vars(["COMMIT_SHA", "IMAGE_VERSION"])) return version_info @@ -162,7 +165,7 @@ def _get_sentry_log_level(C: RunConfiguration) -> LoggingIntegration: def _init_sentry(C: RunConfiguration, version: StrStr) -> None: - sys_ver = version["version"] + sys_ver = version["dlt_version"] release = sys_ver + "_" + version.get("commit_sha", "") _SentryHttpTransport.timeout = C.request_timeout[0] # TODO: ignore certain loggers ie. 
dbt loggers diff --git a/dlt/dbt_runner/__init__.py b/dlt/dbt_runner/__init__.py index 7df9f7aa35..e69de29bb2 100644 --- a/dlt/dbt_runner/__init__.py +++ b/dlt/dbt_runner/__init__.py @@ -1 +0,0 @@ -from ._version import __version__ \ No newline at end of file diff --git a/dlt/dbt_runner/_version.py b/dlt/dbt_runner/_version.py deleted file mode 100644 index 3dc1f76bc6..0000000000 --- a/dlt/dbt_runner/_version.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "0.1.0" diff --git a/dlt/load/__init__.py b/dlt/load/__init__.py index 28501cffe5..0a6c97ed3d 100644 --- a/dlt/load/__init__.py +++ b/dlt/load/__init__.py @@ -1,2 +1 @@ -from dlt._version import loader_version as __version__ from dlt.load.load import Load diff --git a/tests/common/test_logging.py b/tests/common/test_logging.py index 8e3a4bcd42..932b6876a2 100644 --- a/tests/common/test_logging.py +++ b/tests/common/test_logging.py @@ -2,8 +2,9 @@ import logging import json_logging from os import environ +from importlib.metadata import version as pkg_version -from dlt import __version__ as auto_version +from dlt import __version__ as code_version from dlt.common import logger, sleep from dlt.common.typing import StrStr from dlt.common.configuration import configspec @@ -18,12 +19,7 @@ class PureBasicConfiguration(RunConfiguration): @configspec -class PureBasicConfigurationProc(PureBasicConfiguration): - _version: str = "1.6.6" - - -@configspec -class JsonLoggerConfiguration(PureBasicConfigurationProc): +class JsonLoggerConfiguration(PureBasicConfiguration): log_format: str = "JSON" @@ -46,14 +42,13 @@ def environment() -> StrStr: def test_version_extract(environment: StrStr) -> None: version = logger._extract_version_info(PureBasicConfiguration()) - # if component ver not avail use system version - assert version == {'version': auto_version, 'component_name': 'logger'} - version = logger._extract_version_info(PureBasicConfigurationProc()) - assert version["component_version"] == PureBasicConfigurationProc()._version + assert version["dlt_version"].startswith(code_version) + lib_version = pkg_version("python-dlt") + assert version == {'dlt_version': lib_version, 'pipeline_name': 'logger'} # mock image info available in container _mock_image_env(environment) - version = logger._extract_version_info(PureBasicConfigurationProc()) - assert version == {'version': auto_version, 'commit_sha': '192891', 'component_name': 'logger', 'component_version': '1.6.6', 'image_version': 'scale/v:112'} + version = logger._extract_version_info(PureBasicConfiguration()) + assert version == {'dlt_version': lib_version, 'commit_sha': '192891', 'pipeline_name': 'logger', 'image_version': 'scale/v:112'} def test_pod_info_extract(environment: StrStr) -> None: @@ -68,7 +63,7 @@ def test_pod_info_extract(environment: StrStr) -> None: def test_text_logger_init(environment: StrStr) -> None: _mock_image_env(environment) _mock_pod_env(environment) - logger.init_logging_from_config(PureBasicConfigurationProc()) + logger.init_logging_from_config(PureBasicConfiguration()) logger.health("HEALTH data", extra={"metrics": "props"}) logger.metrics("METRICS data", extra={"metrics": "props"}) logger.warning("Warning message here") diff --git a/tests/tools/create_storages.py b/tests/tools/create_storages.py index e3e1f98865..efbc8cf8ff 100644 --- a/tests/tools/create_storages.py +++ b/tests/tools/create_storages.py @@ -2,6 +2,6 @@ from dlt.common.configuration.specs import NormalizeVolumeConfiguration, LoadVolumeConfiguration, SchemaVolumeConfiguration -NormalizeStorage(True, 
NormalizeVolumeConfiguration) -LoadStorage(True, LoadVolumeConfiguration, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) -SchemaStorage(SchemaVolumeConfiguration, makedirs=True) +# NormalizeStorage(True, NormalizeVolumeConfiguration) +# LoadStorage(True, LoadVolumeConfiguration, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) +# SchemaStorage(SchemaVolumeConfiguration, makedirs=True) From 4e4645c6900b4ce09f765317264ebfd9e44293c5 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 14 Oct 2022 12:29:27 +0200 Subject: [PATCH 35/66] applies config injection to storages, adds dictionary config provider for testing --- dlt/common/configuration/__init__.py | 2 +- dlt/common/configuration/inject.py | 52 +- .../configuration/providers/__init__.py | 3 +- .../configuration/providers/dictionary.py | 44 ++ dlt/common/configuration/providers/toml.py | 23 + dlt/common/configuration/resolve.py | 18 +- .../specs/schema_volume_configuration.py | 2 +- dlt/common/schema/schema.py | 3 + dlt/common/storages/file_storage.py | 9 - dlt/common/storages/live_schema_storage.py | 8 +- dlt/common/storages/load_storage.py | 24 +- dlt/common/storages/normalize_storage.py | 19 +- dlt/common/storages/schema_storage.py | 16 +- dlt/common/typing.py | 1 + dlt/dbt_runner/configuration.py | 2 - dlt/extract/extract.py | 4 +- dlt/extract/sources.py | 2 +- dlt/load/configuration.py | 2 - dlt/load/load.py | 13 +- dlt/pipeline/pipeline.py | 20 +- experiments/pipeline/__init__.py | 31 ++ experiments/pipeline/configuration.py | 46 +- experiments/pipeline/extract.py | 102 ---- experiments/pipeline/pipe.py | 449 ------------------ experiments/pipeline/pipeline.py | 100 +--- experiments/pipeline/sources.py | 219 --------- .../configuration/test_configuration.py | 25 - tests/common/configuration/test_inject.py | 24 + .../common/storages/test_normalize_storage.py | 8 +- tests/common/storages/test_schema_storage.py | 10 +- tests/conftest.py | 10 +- tests/load/test_dummy_client.py | 12 +- tests/utils.py | 19 +- 33 files changed, 305 insertions(+), 1017 deletions(-) create mode 100644 dlt/common/configuration/providers/dictionary.py create mode 100644 dlt/common/configuration/providers/toml.py delete mode 100644 experiments/pipeline/extract.py delete mode 100644 experiments/pipeline/pipe.py delete mode 100644 experiments/pipeline/sources.py diff --git a/dlt/common/configuration/__init__.py b/dlt/common/configuration/__init__.py index 212d7b0575..65939a43c6 100644 --- a/dlt/common/configuration/__init__.py +++ b/dlt/common/configuration/__init__.py @@ -1,6 +1,6 @@ from .specs.base_configuration import configspec, is_valid_hint # noqa: F401 from .resolve import make_configuration # noqa: F401 -from .inject import with_config +from .inject import with_config, last_config from .exceptions import ( # noqa: F401 ConfigEntryMissingException, ConfigEnvValueCannotBeCoercedException, ConfigIntegrityException, ConfigFileNotFoundException) diff --git a/dlt/common/configuration/inject.py b/dlt/common/configuration/inject.py index fe48c82de6..dffeba2257 100644 --- a/dlt/common/configuration/inject.py +++ b/dlt/common/configuration/inject.py @@ -2,7 +2,7 @@ import inspect from makefun import wraps from types import ModuleType -from typing import Callable, List, Dict, Type, Any, Optional, Tuple, overload +from typing import Callable, Dict, Type, Any, Optional, Tuple, TypeVar, overload from inspect import Signature, Parameter from dlt.common.typing import StrAny, TFun, AnyFun @@ -11,6 +11,8 @@ # [^.^_]+ splits by . 
or _ _SLEEPING_CAT_SPLIT = re.compile("[^.^_]+") +_LAST_DLT_CONFIG = "_last_dlt_config" +TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration) @overload @@ -33,44 +35,54 @@ def with_config(func: Optional[AnyFun] = None, /, spec: Type[BaseConfiguration] def decorator(f: TFun) -> TFun: SPEC: Type[BaseConfiguration] = None sig: Signature = inspect.signature(f) - kwargs_par = next((p for p in sig.parameters.values() if p.kind == Parameter.VAR_KEYWORD), None) + kwargs_arg = next((p for p in sig.parameters.values() if p.kind == Parameter.VAR_KEYWORD), None) + spec_arg: Parameter = None if spec is None: SPEC = _spec_from_signature(_get_spec_name_from_f(f), inspect.getmodule(f), sig, only_kw) else: SPEC = spec - # for all positional parameters that do not have default value, set default for p in sig.parameters.values(): + # for all positional parameters that do not have default value, set default if hasattr(SPEC, p.name) and p.default == Parameter.empty: p._default = None # type: ignore + if p.annotation is SPEC: + # if any argument has type SPEC then us it to take initial value + spec_arg = p @wraps(f, new_sig=sig) def _wrap(*args: Any, **kwargs: Any) -> Any: - # for calls providing all parameters to the func, configuration may not be resolved - # if len(args) + len(kwargs) == len(sig.parameters): - # return f(*args, **kwargs) - # bind parameters to signature bound_args = sig.bind_partial(*args, **kwargs) bound_args.apply_defaults() - # if namespace derivation function was provided then call it - nonlocal namespaces - if namespace_f: - namespaces = (namespace_f(bound_args.arguments), ) - # namespaces may be a string - if isinstance(namespaces, str): - namespaces = (namespaces,) - # resolve SPEC - config = make_configuration(SPEC(), namespaces=namespaces, initial_value=bound_args.arguments) + # for calls containing resolved spec in the kwargs, we do not need to resolve again + config: BaseConfiguration = None + if _LAST_DLT_CONFIG in kwargs: + config = last_config(**kwargs) + else: + # if namespace derivation function was provided then call it + nonlocal namespaces + if namespace_f: + namespaces = (namespace_f(bound_args.arguments), ) + # namespaces may be a string + if isinstance(namespaces, str): + namespaces = (namespaces,) + # resolve SPEC + if spec_arg: + config = bound_args.arguments.get(spec_arg.name, None) + config = make_configuration(config or SPEC(), namespaces=namespaces, initial_value=bound_args.arguments) resolved_params = dict(config) # overwrite or add resolved params for p in sig.parameters.values(): if p.name in resolved_params: bound_args.arguments[p.name] = resolved_params.pop(p.name) + if p.annotation is SPEC: + bound_args.arguments[p.name] = config # pass all other config parameters into kwargs if present - if kwargs_par is not None: - bound_args.arguments[kwargs_par.name].update(resolved_params) + if kwargs_arg is not None: + bound_args.arguments[kwargs_arg.name].update(resolved_params) + bound_args.arguments[kwargs_arg.name][_LAST_DLT_CONFIG] = config # call the function with injected config return f(*bound_args.args, **bound_args.kwargs) @@ -88,6 +100,10 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: return decorator(func) +def last_config(**kwargs: Any) -> TConfiguration: + return kwargs[_LAST_DLT_CONFIG] + + def _get_spec_name_from_f(f: AnyFun) -> str: func_name = f.__qualname__.replace(".", "") # func qual name contains position in the module, separated by dots diff --git a/dlt/common/configuration/providers/__init__.py 
b/dlt/common/configuration/providers/__init__.py index 10d7b9b24a..42488f1b96 100644 --- a/dlt/common/configuration/providers/__init__.py +++ b/dlt/common/configuration/providers/__init__.py @@ -1,2 +1,3 @@ from .provider import Provider -from .environ import EnvironProvider \ No newline at end of file +from .environ import EnvironProvider +from .dictionary import DictionaryProvider \ No newline at end of file diff --git a/dlt/common/configuration/providers/dictionary.py b/dlt/common/configuration/providers/dictionary.py new file mode 100644 index 0000000000..906c8d6748 --- /dev/null +++ b/dlt/common/configuration/providers/dictionary.py @@ -0,0 +1,44 @@ +from contextlib import contextmanager +from typing import Any, Iterator, Optional, Type, Tuple + +from dlt.common.typing import StrAny + +from .provider import Provider + + +class DictionaryProvider(Provider): + + def __init__(self) -> None: + self._values: StrAny = {} + pass + + @property + def name(self) -> str: + return "Dictionary Provider" + + def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Optional[Any], str]: + full_path = namespaces + (key,) + full_key = "__".join(full_path) + node = self._values + try: + for k in full_path: + node = node[k] + return node, full_key + except KeyError: + return None, full_key + + @property + def supports_secrets(self) -> bool: + return True + + @property + def supports_namespaces(self) -> bool: + return True + + + @contextmanager + def values(self, v: StrAny) -> Iterator[None]: + p_values = self._values + self._values = v + yield + self._values = p_values diff --git a/dlt/common/configuration/providers/toml.py b/dlt/common/configuration/providers/toml.py new file mode 100644 index 0000000000..6b433ce146 --- /dev/null +++ b/dlt/common/configuration/providers/toml.py @@ -0,0 +1,23 @@ +import os +import inspect +import dataclasses +import tomlkit +from inspect import Signature, Parameter +from typing import Any, List, Type +# from makefun import wraps +from functools import wraps + +from dlt.common.typing import DictStrAny, StrAny, TAny, TFun +from dlt.common.configuration import make_configuration, is_valid_hint +from dlt.common.configuration.specs import BaseConfiguration + + +def _read_toml(file_name: str) -> StrAny: + config_file_path = os.path.abspath(os.path.join(".", "experiments/.dlt", file_name)) + + if os.path.isfile(config_file_path): + with open(config_file_path, "r", encoding="utf-8") as f: + # use whitespace preserving parser + return tomlkit.load(f) + else: + return {} \ No newline at end of file diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index e0347d460f..6509298044 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -1,7 +1,5 @@ import ast import inspect -import sys -import semver import dataclasses from collections.abc import Mapping as C_Mapping from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, get_origin @@ -44,7 +42,7 @@ def make_configuration(config: TConfiguration, *, namespaces: Tuple[str, ...] 
= except ConfigEntryMissingException: if not accept_partial: raise - _add_module_version(config) + # _add_module_version(config) return config @@ -88,13 +86,13 @@ def serialize_value(value: Any) -> Any: return coerce_type("text", value_dt, value) -def _add_module_version(config: BaseConfiguration) -> None: - try: - v = sys._getframe(1).f_back.f_globals["__version__"] - semver.VersionInfo.parse(v) - setattr(config, "_version", v) # noqa: B010 - except KeyError: - pass +# def _add_module_version(config: BaseConfiguration) -> None: +# try: +# v = sys._getframe(1).f_back.f_globals["__version__"] +# semver.VersionInfo.parse(v) +# setattr(config, "_version", v) # noqa: B010 +# except KeyError: +# pass def _resolve_config_fields(config: BaseConfiguration, namespaces: Tuple[str, ...], accept_partial: bool) -> None: diff --git a/dlt/common/configuration/specs/schema_volume_configuration.py b/dlt/common/configuration/specs/schema_volume_configuration.py index 3b1d8c4df9..b2be72ed28 100644 --- a/dlt/common/configuration/specs/schema_volume_configuration.py +++ b/dlt/common/configuration/specs/schema_volume_configuration.py @@ -5,7 +5,7 @@ TSchemaFileFormat = Literal["json", "yaml"] -@configspec +@configspec(init=True) class SchemaVolumeConfiguration(BaseConfiguration): schema_volume_path: str = None # path to volume with default schemas import_schema_path: Optional[str] = None # import schema from external location diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index 34fedce7fa..9382f572cf 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -77,7 +77,10 @@ def from_dict(cls, d: DictStrAny) -> "Schema": # bump version if modified utils.bump_version_if_modified(stored_schema) + return cls.from_stored_schema(stored_schema) + @classmethod + def from_stored_schema(cls, stored_schema: TStoredSchema) -> "Schema": # create new instance from dict self: Schema = cls(stored_schema["name"], normalizers=stored_schema.get("normalizers", None)) self._schema_tables = stored_schema.get("tables") or {} diff --git a/dlt/common/storages/file_storage.py b/dlt/common/storages/file_storage.py index d75009a20b..b6cf1d314d 100644 --- a/dlt/common/storages/file_storage.py +++ b/dlt/common/storages/file_storage.py @@ -164,12 +164,3 @@ def validate_file_name_component(name: str) -> None: # component cannot contain "." if "." in name: raise pathvalidate.error.InvalidCharError(reason="Component name cannot contain . 
(dots)") - pass - - # @staticmethod - # def get_file_stem(path: str) -> str: - # return Path(os.path.basename(path)).stem - - # @staticmethod - # def get_file_name(path: str) -> str: - # return Path(path).name diff --git a/dlt/common/storages/live_schema_storage.py b/dlt/common/storages/live_schema_storage.py index b74a6769de..3c04ec7439 100644 --- a/dlt/common/storages/live_schema_storage.py +++ b/dlt/common/storages/live_schema_storage.py @@ -1,12 +1,13 @@ -from typing import Dict +from typing import Any, Dict -from dlt.common.configuration.specs import SchemaVolumeConfiguration from dlt.common.schema.schema import Schema from dlt.common.storages.schema_storage import SchemaStorage +from dlt.common.configuration.specs import SchemaVolumeConfiguration class LiveSchemaStorage(SchemaStorage): - def __init__(self, C: SchemaVolumeConfiguration, makedirs: bool = False) -> None: + + def __init__(self, C: SchemaVolumeConfiguration = None, makedirs: bool = False) -> None: self.live_schemas: Dict[str, Schema] = {} super().__init__(C, makedirs) @@ -34,7 +35,6 @@ def commit_live_schema(self, name: str) -> Schema: # if live schema exists and is modified then it must be used as an import schema live_schema = self.live_schemas.get(name) if live_schema and live_schema.stored_version_hash != live_schema.version_hash: - print("bumping and saving") live_schema.bump_version() if self.C.import_schema_path: # overwrite import schemas if specified diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index 13c5dac4b3..74246f60b5 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -1,10 +1,11 @@ import os from os.path import join from pathlib import Path -from typing import Iterable, NamedTuple, Literal, Optional, Sequence, Set, get_args +from typing import Iterable, NamedTuple, Literal, Optional, Sequence, Set, get_args, overload from dlt.common import json, pendulum -from dlt.common.typing import DictStrAny, StrAny +from dlt.common.configuration.inject import with_config +from dlt.common.typing import ConfigValue, DictStrAny, StrAny from dlt.common.storages.file_storage import FileStorage from dlt.common.data_writers import TLoaderFileFormat, DataWriter from dlt.common.configuration.specs import LoadVolumeConfiguration @@ -41,23 +42,32 @@ class LoadStorage(DataItemStorage, VersionedStorage): ALL_SUPPORTED_FILE_FORMATS: Set[TLoaderFileFormat] = set(get_args(TLoaderFileFormat)) + @overload + def __init__(self, is_owner: bool, preferred_file_format: TLoaderFileFormat, supported_file_formats: Iterable[TLoaderFileFormat], config: LoadVolumeConfiguration) -> None: + ... + + @overload + def __init__(self, is_owner: bool, preferred_file_format: TLoaderFileFormat, supported_file_formats: Iterable[TLoaderFileFormat], config: LoadVolumeConfiguration = ConfigValue) -> None: + ... 
+ + @with_config(spec=LoadVolumeConfiguration, namespaces=("load",)) def __init__( self, is_owner: bool, - C: LoadVolumeConfiguration, preferred_file_format: TLoaderFileFormat, - supported_file_formats: Iterable[TLoaderFileFormat] + supported_file_formats: Iterable[TLoaderFileFormat], + config: LoadVolumeConfiguration = ConfigValue ) -> None: if not LoadStorage.ALL_SUPPORTED_FILE_FORMATS.issuperset(supported_file_formats): raise TerminalValueError(supported_file_formats) if preferred_file_format not in supported_file_formats: raise TerminalValueError(preferred_file_format) self.supported_file_formats = supported_file_formats - self.delete_completed_jobs = C.delete_completed_jobs + self.config = config super().__init__( preferred_file_format, LoadStorage.STORAGE_VERSION, - is_owner, FileStorage(C.load_volume_path, "t", makedirs=is_owner) + is_owner, FileStorage(config.load_volume_path, "t", makedirs=is_owner) ) if is_owner: self.initialize_storage() @@ -181,7 +191,7 @@ def complete_load_package(self, load_id: str) -> None: load_path = self.get_package_path(load_id) has_failed_jobs = len(self.list_failed_jobs(load_id)) > 0 # delete load that does not contain failed jobs - if self.delete_completed_jobs and not has_failed_jobs: + if self.config.delete_completed_jobs and not has_failed_jobs: self.storage.delete_folder(load_path, recursively=True) else: completed_path = self.get_completed_package_path(load_id) diff --git a/dlt/common/storages/normalize_storage.py b/dlt/common/storages/normalize_storage.py index 446bec0063..b668f86aaa 100644 --- a/dlt/common/storages/normalize_storage.py +++ b/dlt/common/storages/normalize_storage.py @@ -1,10 +1,12 @@ -from typing import List, Sequence, NamedTuple +from typing import List, Sequence, NamedTuple, overload from itertools import groupby from pathlib import Path from dlt.common.storages.file_storage import FileStorage +from dlt.common.configuration import with_config from dlt.common.configuration.specs import NormalizeVolumeConfiguration from dlt.common.storages.versioned_storage import VersionedStorage +from dlt.common.typing import ConfigValue class TParsedNormalizeFileName(NamedTuple): @@ -18,9 +20,18 @@ class NormalizeStorage(VersionedStorage): STORAGE_VERSION = "1.0.0" EXTRACTED_FOLDER: str = "extracted" # folder within the volume where extracted files to be normalized are stored - def __init__(self, is_owner: bool, C: NormalizeVolumeConfiguration) -> None: - super().__init__(NormalizeStorage.STORAGE_VERSION, is_owner, FileStorage(C.normalize_volume_path, "t", makedirs=is_owner)) - self.CONFIG = C + @overload + def __init__(self, is_owner: bool, config: NormalizeVolumeConfiguration) -> None: + ... + + @overload + def __init__(self, is_owner: bool, config: NormalizeVolumeConfiguration = ConfigValue) -> None: + ... 
+ + @with_config(spec=NormalizeVolumeConfiguration, namespaces=("normalize",)) + def __init__(self, is_owner: bool, config: NormalizeVolumeConfiguration = ConfigValue) -> None: + super().__init__(NormalizeStorage.STORAGE_VERSION, is_owner, FileStorage(config.normalize_volume_path, "t", makedirs=is_owner)) + self.config = config if is_owner: self.initialize_storage() diff --git a/dlt/common/storages/schema_storage.py b/dlt/common/storages/schema_storage.py index ed1f0f0513..305a901c33 100644 --- a/dlt/common/storages/schema_storage.py +++ b/dlt/common/storages/schema_storage.py @@ -1,13 +1,14 @@ import os import re import yaml -from typing import Iterator, List, Mapping +from typing import Iterator, List, Mapping, overload from dlt.common import json, logger +from dlt.common.configuration import with_config from dlt.common.configuration.specs import SchemaVolumeConfiguration, TSchemaFileFormat from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import Schema, verify_schema_hash -from dlt.common.typing import DictStrAny +from dlt.common.typing import DictStrAny, ConfigValue from dlt.common.storages.exceptions import InStorageSchemaModified, SchemaNotFoundError @@ -17,7 +18,16 @@ class SchemaStorage(Mapping[str, Schema]): SCHEMA_FILE_NAME = "schema.%s" NAMED_SCHEMA_FILE_PATTERN = f"%s_{SCHEMA_FILE_NAME}" + @overload def __init__(self, C: SchemaVolumeConfiguration, makedirs: bool = False) -> None: + ... + + @overload + def __init__(self, C: SchemaVolumeConfiguration = ConfigValue, makedirs: bool = False) -> None: + ... + + @with_config(spec=SchemaVolumeConfiguration, namespaces=("schema",)) + def __init__(self, C: SchemaVolumeConfiguration = ConfigValue, makedirs: bool = False) -> None: self.C = C self.storage = FileStorage(C.schema_volume_path, makedirs=makedirs) @@ -152,3 +162,5 @@ def _file_name_in_store(self, name: str, fmt: TSchemaFileFormat) -> str: return SchemaStorage.NAMED_SCHEMA_FILE_PATTERN % (name, fmt) else: return SchemaStorage.SCHEMA_FILE_NAME % fmt + +SchemaStorage(makedirs=True) \ No newline at end of file diff --git a/dlt/common/typing.py b/dlt/common/typing.py index 71058093d5..24bc16de9c 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -26,6 +26,7 @@ TSecretValue = NewType("TSecretValue", str) # represent secret value ie. coming from Kubernetes/Docker secrets or other providers TDataItem = Any # a single data item extracted from data source, normalized and loaded +ConfigValue: None = None TVariantBase = TypeVar("TVariantBase", covariant=True) TVariantRV = Tuple[str, Any] diff --git a/dlt/dbt_runner/configuration.py b/dlt/dbt_runner/configuration.py index 7516e6eff6..888c8b7993 100644 --- a/dlt/dbt_runner/configuration.py +++ b/dlt/dbt_runner/configuration.py @@ -7,8 +7,6 @@ from dlt.common.configuration.providers import EnvironProvider from dlt.common.configuration.specs import PoolRunnerConfiguration, TPoolType, PostgresCredentials, GcpClientCredentials -from . 
import __version__ - @configspec class DBTRunnerConfiguration(PoolRunnerConfiguration): diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index 8acf0e014f..0b9d96578e 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -8,8 +8,8 @@ from dlt.common.configuration.specs import NormalizeVolumeConfiguration -from experiments.pipeline.pipe import PipeIterator -from experiments.pipeline.sources import DltResource, DltSource +from dlt.extract.pipe import PipeIterator +from dlt.extract.sources import DltResource, DltSource class ExtractorStorage(DataItemStorage, NormalizeStorage): diff --git a/dlt/extract/sources.py b/dlt/extract/sources.py index 4b85646909..c4703b1d09 100644 --- a/dlt/extract/sources.py +++ b/dlt/extract/sources.py @@ -10,7 +10,7 @@ from dlt.common.schema.schema import Schema from dlt.common.schema.typing import TPartialTableSchema, TTableSchema, TTableSchemaColumns, TWriteDisposition -from experiments.pipeline.pipe import FilterItem, Pipe, CreatePipeException, PipeIterator +from dlt.extract.pipe import FilterItem, Pipe, CreatePipeException, PipeIterator class TTableSchemaTemplate(TypedDict, total=False): diff --git a/dlt/load/configuration.py b/dlt/load/configuration.py index 797a82e99b..d7e9ef16bd 100644 --- a/dlt/load/configuration.py +++ b/dlt/load/configuration.py @@ -4,8 +4,6 @@ from dlt.common.configuration import configspec, make_configuration from dlt.common.configuration.specs import BaseConfiguration, PoolRunnerConfiguration, LoadVolumeConfiguration, TPoolType -from . import __version__ - @configspec class LoaderClientConfiguration(BaseConfiguration): client_type: str = None # which destination to load data to diff --git a/dlt/load/load.py b/dlt/load/load.py index ddcec05682..15a673181a 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -31,9 +31,9 @@ class Load(Runnable[ThreadPool]): job_counter: Counter = None job_wait_summary: Summary = None - def __init__(self, C: LoaderConfiguration, collector: CollectorRegistry, client_initial_values: StrAny = None, is_storage_owner: bool = False) -> None: - self.CONFIG = C - self.load_client_cls = self.import_client_cls(C.client_type, initial_values=client_initial_values) + def __init__(self, config: LoaderConfiguration, collector: CollectorRegistry, client_initial_values: StrAny = None, is_storage_owner: bool = False) -> None: + self.config = config + self.load_client_cls = self.import_client_cls(config.client_type, initial_values=client_initial_values) self.pool: ThreadPool = None self.load_storage: LoadStorage = self.create_storage(is_storage_owner) try: @@ -57,7 +57,6 @@ def import_client_cls(client_type: str, initial_values: StrAny = None) -> Type[J def create_storage(self, is_storage_owner: bool) -> LoadStorage: load_storage = LoadStorage( is_storage_owner, - self.CONFIG, self.load_client_cls.capabilities()["preferred_loader_file_format"], self.load_client_cls.capabilities()["supported_loader_file_formats"] ) @@ -113,7 +112,7 @@ def spool_new_jobs(self, load_id: str, schema: Schema) -> Tuple[int, List[LoadJo # use thread based pool as jobs processing is mostly I/O and we do not want to pickle jobs # TODO: combine files by providing a list of files pertaining to same table into job, so job must be # extended to accept a list - load_files = self.load_storage.list_new_jobs(load_id)[:self.CONFIG.workers] + load_files = self.load_storage.list_new_jobs(load_id)[:self.config.workers] file_count = len(load_files) if file_count == 0: logger.info(f"No new jobs found in {load_id}") @@ -208,11 +207,11 @@ 
def run(self, pool: ThreadPool) -> TRunMetrics: logger.info(f"Loaded schema name {schema.name} and version {schema.stored_version}") # initialize analytical storage ie. create dataset required by passed schema with self.load_client_cls(schema) as client: - logger.info(f"Client {self.CONFIG.client_type} will start load") + logger.info(f"Client {self.config.client_type} will start load") client.initialize_storage() schema_update = self.load_storage.begin_schema_update(load_id) if schema_update: - logger.info(f"Client {self.CONFIG.client_type} will update schema to package schema") + logger.info(f"Client {self.config.client_type} will update schema to package schema") # TODO: this should rather generate an SQL job(s) to be executed PRE loading client.update_storage_schema() self.load_storage.commit_schema_update(load_id) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index eaf6f4116a..b365aded5a 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -179,11 +179,11 @@ def normalize(self, workers: int = 1, max_events_in_chunk: int = 100000) -> int: raise NotImplementedError("Do not use workers in interactive mode ie. in notebook") self._verify_normalize_instance() # set runtime parameters - self._normalize_instance.CONFIG.workers = workers + self._normalize_instance.config.workers = workers # switch to thread pool for single worker - self._normalize_instance.CONFIG.pool_type = "thread" if workers == 1 else "process" + self._normalize_instance.config.pool_type = "thread" if workers == 1 else "process" try: - ec = runner.run_pool(self._normalize_instance.CONFIG, self._normalize_instance) + ec = runner.run_pool(self._normalize_instance.config, self._normalize_instance) # in any other case we raise if runner exited with status failed if runner.LAST_RUN_METRICS.has_failed: raise PipelineStepFailed("normalize", self.last_run_exception, runner.LAST_RUN_METRICS) @@ -194,10 +194,10 @@ def normalize(self, workers: int = 1, max_events_in_chunk: int = 100000) -> int: def load(self, max_parallel_loads: int = 20) -> int: self._verify_loader_instance() - self._loader_instance.CONFIG.workers = max_parallel_loads + self._loader_instance.config.workers = max_parallel_loads self._loader_instance.load_client_cls.CONFIG.DEFAULT_SCHEMA_NAME = self.default_schema_name # type: ignore try: - ec = runner.run_pool(self._loader_instance.CONFIG, self._loader_instance) + ec = runner.run_pool(self._loader_instance.config, self._loader_instance) # in any other case we raise if runner exited with status failed if runner.LAST_RUN_METRICS.has_failed: raise PipelineStepFailed("load", self.last_run_exception, runner.LAST_RUN_METRICS) @@ -283,14 +283,14 @@ def sql_client(self, schema_name: str = None) -> SqlClientBase[Any]: if isinstance(c, SqlJobClientBase): return c.sql_client else: - raise SqlClientNotAvailable(self._loader_instance.CONFIG.client_type) + raise SqlClientNotAvailable(self._loader_instance.config.client_type) def run_in_pool(self, run_f: Callable[..., Any]) -> int: # internal runners should work in single mode - self._loader_instance.CONFIG.is_single_run = True - self._loader_instance.CONFIG.exit_on_exception = True - self._normalize_instance.CONFIG.is_single_run = True - self._normalize_instance.CONFIG.exit_on_exception = True + self._loader_instance.config.is_single_run = True + self._loader_instance.config.exit_on_exception = True + self._normalize_instance.config.is_single_run = True + self._normalize_instance.config.exit_on_exception = True def _run(_: Any) -> TRunMetrics: 
rv = run_f() diff --git a/experiments/pipeline/__init__.py b/experiments/pipeline/__init__.py index f7a07ad9dc..9db28049a4 100644 --- a/experiments/pipeline/__init__.py +++ b/experiments/pipeline/__init__.py @@ -6,4 +6,35 @@ # if name == 'y': # return 3 # raise AttributeError(f"module '{__name__}' has no attribute '{name}'") +import tempfile +from dlt.common.typing import TSecretValue, Any +from dlt.common.configuration import with_config + +from experiments.pipeline.configuration import PipelineConfiguration +from experiments.pipeline.pipeline import Pipeline + + +# @overload +# def configure(self, +# pipeline_name: str = None, +# working_dir: str = None, +# pipeline_secret: TSecretValue = None, +# drop_existing_data: bool = False, +# import_schema_path: str = None, +# export_schema_path: str = None, +# destination_name: str = None, +# log_level: str = "INFO" +# ) -> None: +# ... + + +@with_config(spec=PipelineConfiguration, auto_namespace=True) +def configure(pipeline_name: str = None, working_dir: str = None, pipeline_secret: TSecretValue = None, **kwargs: Any) -> Pipeline: + # if working_dir not provided use temp folder + if not working_dir: + working_dir = tempfile.gettempdir() + return Pipeline(pipeline_name, working_dir, pipeline_secret, kwargs["runtime"]) + +def run() -> Pipeline: + return configure().extract() \ No newline at end of file diff --git a/experiments/pipeline/configuration.py b/experiments/pipeline/configuration.py index 159f13c4c4..02109abf50 100644 --- a/experiments/pipeline/configuration.py +++ b/experiments/pipeline/configuration.py @@ -1,38 +1,18 @@ -import os -import inspect -import dataclasses -import tomlkit -from inspect import Signature, Parameter -from typing import Any, List, Type -# from makefun import wraps -from functools import wraps +from typing import Optional -from dlt.common.typing import DictStrAny, StrAny, TAny, TFun -from dlt.common.configuration import make_configuration, is_valid_hint -from dlt.common.configuration.specs import BaseConfiguration +from dlt.common.configuration import configspec +from dlt.common.configuration.specs import RunConfiguration, BaseConfiguration -# _POS_PARAMETER_KINDS = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD, Parameter.VAR_POSITIONAL) +from dlt.common.typing import TSecretValue +from dlt.common.utils import uniq_id -def _read_toml(file_name: str) -> StrAny: - config_file_path = os.path.abspath(os.path.join(".", "experiments/.dlt", file_name)) - if os.path.isfile(config_file_path): - with open(config_file_path, "r", encoding="utf-8") as f: - # use whitespace preserving parser - return tomlkit.load(f) - else: - return {} - - -def get_config_from_toml(): - pass - - -def get_config(SPEC: Type[TAny], key: str = None, namespace: str = None, initial_value: Any = None, accept_partial: bool = False) -> TAny: - # TODO: implement key and namespace - return make_configuration(SPEC(), initial_value=initial_value, accept_partial=accept_partial) - - -def spec_from_dict(): - pass +@configspec +class PipelineConfiguration(BaseConfiguration): + working_dir: Optional[str] = None + pipeline_secret: Optional[TSecretValue] = None + runtime: RunConfiguration + def check_integrity(self) -> None: + if self.pipeline_secret: + self.pipeline_secret = uniq_id() diff --git a/experiments/pipeline/extract.py b/experiments/pipeline/extract.py deleted file mode 100644 index 8acf0e014f..0000000000 --- a/experiments/pipeline/extract.py +++ /dev/null @@ -1,102 +0,0 @@ -import os -from typing import List - -from dlt.common.utils 
import uniq_id -from dlt.common.sources import TDirectDataItem, TDataItem -from dlt.common.schema import utils, TSchemaUpdate -from dlt.common.storages import NormalizeStorage, DataItemStorage -from dlt.common.configuration.specs import NormalizeVolumeConfiguration - - -from experiments.pipeline.pipe import PipeIterator -from experiments.pipeline.sources import DltResource, DltSource - - -class ExtractorStorage(DataItemStorage, NormalizeStorage): - EXTRACT_FOLDER = "extract" - - def __init__(self, C: NormalizeVolumeConfiguration) -> None: - # data item storage with jsonl with pua encoding - super().__init__("puae-jsonl", False, C) - self.initialize_storage() - - def initialize_storage(self) -> None: - self.storage.create_folder(ExtractorStorage.EXTRACT_FOLDER, exists_ok=True) - - def create_extract_id(self) -> str: - extract_id = uniq_id() - self.storage.create_folder(self._get_extract_path(extract_id)) - return extract_id - - def commit_extract_files(self, extract_id: str, with_delete: bool = True) -> None: - extract_path = self._get_extract_path(extract_id) - for file in self.storage.list_folder_files(extract_path, to_root=False): - from_file = os.path.join(extract_path, file) - to_file = os.path.join(NormalizeStorage.EXTRACTED_FOLDER, file) - if with_delete: - self.storage.atomic_rename(from_file, to_file) - else: - # create hardlink which will act as a copy - self.storage.link_hard(from_file, to_file) - if with_delete: - self.storage.delete_folder(extract_path, recursively=True) - - def _get_data_item_path_template(self, load_id: str, schema_name: str, table_name: str) -> str: - template = NormalizeStorage.build_extracted_file_stem(schema_name, table_name, "%s") - return self.storage.make_full_path(os.path.join(self._get_extract_path(load_id), template)) - - def _get_extract_path(self, extract_id: str) -> str: - return os.path.join(ExtractorStorage.EXTRACT_FOLDER, extract_id) - - -def extract(source: DltSource, storage: ExtractorStorage) -> TSchemaUpdate: - dynamic_tables: TSchemaUpdate = {} - schema = source.schema - extract_id = storage.create_extract_id() - - def _write_item(table_name: str, item: TDirectDataItem) -> None: - # normalize table name before writing so the name match the name in schema - # note: normalize function should be cached so there's almost no penalty on frequent calling - # note: column schema is not required for jsonl writer used here - # TODO: consider dropping DLT_METADATA_FIELD in all items before writing, this however takes CPU time - # event.pop(DLT_METADATA_FIELD, None) # type: ignore - storage.write_data_item(extract_id, schema.name, schema.normalize_table_name(table_name), item, None) - - def _write_dynamic_table(resource: DltResource, item: TDataItem) -> None: - table_name = resource._table_name_hint_fun(item) - existing_table = dynamic_tables.get(table_name) - if existing_table is None: - dynamic_tables[table_name] = [resource.table_schema(item)] - else: - # quick check if deep table merge is required - if resource._table_has_other_dynamic_hints: - new_table = resource.table_schema(item) - # this merges into existing table in place - utils.merge_tables(existing_table[0], new_table) - else: - # if there are no other dynamic hints besides name then we just leave the existing partial table - pass - # write to storage with inferred table name - _write_item(table_name, item) - - # yield from all selected pipes - for pipe_item in PipeIterator.from_pipes(source.pipes): - # get partial table from table template - resource = 
source.resource_by_pipe(pipe_item.pipe) - if resource._table_name_hint_fun: - if isinstance(pipe_item.item, List): - for item in pipe_item.item: - _write_dynamic_table(resource, item) - else: - _write_dynamic_table(resource, pipe_item.item) - else: - # write item belonging to table with static name - _write_item(resource.name, pipe_item.item) - - # flush all buffered writers - storage.close_writers(extract_id) - storage.commit_extract_files(extract_id) - - # returns set of partial tables - return dynamic_tables - diff --git a/experiments/pipeline/pipe.py b/experiments/pipeline/pipe.py deleted file mode 100644 index da32a6c12f..0000000000 --- a/experiments/pipeline/pipe.py +++ /dev/null @@ -1,449 +0,0 @@ -import types -import asyncio -from asyncio import Future -from concurrent.futures import ThreadPoolExecutor -from copy import deepcopy -from threading import Thread -from typing import Optional, Sequence, Union, Callable, Iterable, Iterator, List, NamedTuple, Awaitable, Tuple, Type, TYPE_CHECKING -from dlt.common.configuration.inject import with_config -from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec - -from dlt.common.typing import TDataItem -from dlt.common.sources import TDirectDataItem, TResolvableDataItem - -if TYPE_CHECKING: - TItemFuture = Future[TDirectDataItem] -else: - TItemFuture = Future - -from dlt.common.exceptions import DltException -from dlt.common.time import sleep - - -class PipeItem(NamedTuple): - item: TDirectDataItem - step: int - pipe: "Pipe" - - -class ResolvablePipeItem(NamedTuple): - # mypy unable to handle recursive types, ResolvablePipeItem should take itself in "item" - item: Union[TResolvableDataItem, Iterator[TResolvableDataItem]] - step: int - pipe: "Pipe" - - -class FuturePipeItem(NamedTuple): - item: TItemFuture - step: int - pipe: "Pipe" - - -class SourcePipeItem(NamedTuple): - item: Union[Iterator[TResolvableDataItem], Iterator[ResolvablePipeItem]] - step: int - pipe: "Pipe" - - -# pipeline step may be iterator of data items or mapping function that returns data item or another iterator -TPipeStep = Union[ - Iterable[TResolvableDataItem], - Iterator[TResolvableDataItem], - Callable[[TDirectDataItem], TResolvableDataItem], - Callable[[TDirectDataItem], Iterator[TResolvableDataItem]], - Callable[[TDirectDataItem], Iterator[ResolvablePipeItem]] -] - - -class ForkPipe: - def __init__(self, pipe: "Pipe", step: int = -1) -> None: - self._pipes: List[Tuple["Pipe", int]] = [] - self.add_pipe(pipe, step) - - def add_pipe(self, pipe: "Pipe", step: int = -1) -> None: - if pipe not in self._pipes: - self._pipes.append((pipe, step)) - - def has_pipe(self, pipe: "Pipe") -> bool: - return pipe in [p[0] for p in self._pipes] - - def __call__(self, item: TDirectDataItem) -> Iterator[ResolvablePipeItem]: - for i, (pipe, step) in enumerate(self._pipes): - _it = item if i == 0 else deepcopy(item) - # always start at the beginning - yield ResolvablePipeItem(_it, step, pipe) - - -class FilterItem: - def __init__(self, filter_f: Callable[[TDataItem], bool]) -> None: - self._filter_f = filter_f - - def __call__(self, item: TDirectDataItem) -> Optional[TDirectDataItem]: - # item may be a list TDataItem or a single TDataItem - if isinstance(item, list): - item = [i for i in item if self._filter_f(i)] - if not item: - # item was fully consumed by the filter - return None - return item - else: - return item if self._filter_f(item) else None - - -class Pipe: - def __init__(self, name: str, steps: List[TPipeStep] = None, parent: "Pipe" = 
None) -> None: - self.name = name - self._steps: List[TPipeStep] = steps or [] - self._backup_steps: List[TPipeStep] = None - self._pipe_id = f"{name}_{id(self)}" - self.parent = parent - - @classmethod - def from_iterable(cls, name: str, gen: Union[Iterable[TResolvableDataItem], Iterator[TResolvableDataItem]]) -> "Pipe": - if isinstance(gen, Iterable): - gen = iter(gen) - return cls(name, [gen]) - - @property - def head(self) -> TPipeStep: - return self._steps[0] - - @property - def tail(self) -> TPipeStep: - return self._steps[-1] - - @property - def steps(self) -> List[TPipeStep]: - return self._steps - - def __getitem__(self, i: int) -> TPipeStep: - return self._steps[i] - - def __len__(self) -> int: - return len(self._steps) - - def fork(self, child_pipe: "Pipe", child_step: int = -1) -> "Pipe": - if len(self._steps) == 0: - raise CreatePipeException("Cannot fork to empty pipe") - fork_step = self.tail - if not isinstance(fork_step, ForkPipe): - fork_step = ForkPipe(child_pipe, child_step) - self.add_step(fork_step) - else: - if not fork_step.has_pipe(child_pipe): - fork_step.add_pipe(child_pipe, child_step) - return self - - def clone(self) -> "Pipe": - p = Pipe(self.name, self._steps.copy(), self.parent) - # clone shares the id with the original - p._pipe_id = self._pipe_id - return p - - # def backup(self) -> None: - # if self.has_backup: - # raise PipeBackupException("Pipe backup already exists, restore pipe first") - # self._backup_steps = self._steps.copy() - - # @property - # def has_backup(self) -> bool: - # return self._backup_steps is not None - - - # def restore(self) -> None: - # if not self.has_backup: - # raise PipeBackupException("No pipe backup to restore") - # self._steps = self._backup_steps - # self._backup_steps = None - - def add_step(self, step: TPipeStep) -> "Pipe": - if len(self._steps) == 0 and self.parent is None: - # first element must be iterable or iterator - if not isinstance(step, (Iterable, Iterator)): - raise CreatePipeException("First step of independent pipe must be Iterable or Iterator") - else: - if isinstance(step, Iterable): - step = iter(step) - self._steps.append(step) - else: - if isinstance(step, (Iterable, Iterator)): - if self.parent is not None: - raise CreatePipeException("Iterable or Iterator cannot be a step in dependent pipe") - else: - raise CreatePipeException("Iterable or Iterator can only be a first step in independent pipe") - if not callable(step): - raise CreatePipeException("Pipe step must be a callable taking exactly one data item as input") - self._steps.append(step) - return self - - def full_pipe(self) -> "Pipe": - if self.parent: - pipe = self.parent.full_pipe().steps - else: - pipe = [] - - # return pipe with resolved dependencies - pipe.extend(self._steps) - return Pipe(self.name, pipe) - - def __repr__(self) -> str: - return f"Pipe {self.name} ({self._pipe_id}) at {id(self)}" - - -class PipeIterator(Iterator[PipeItem]): - - @configspec - class PipeIteratorConfiguration: - max_parallel_items: int = 100 - worker_threads: int = 5 - futures_poll_interval: float = 0.01 - - - def __init__(self, max_parallel_items: int, worker_threads, futures_poll_interval: float) -> None: - self.max_parallel_items = max_parallel_items - self.worker_threads = worker_threads - self.futures_poll_interval = futures_poll_interval - - self._async_pool: asyncio.AbstractEventLoop = None - self._async_pool_thread: Thread = None - self._thread_pool: ThreadPoolExecutor = None - self._sources: List[SourcePipeItem] = [] - self._futures: 
List[FuturePipeItem] = [] - - @classmethod - @with_config(spec=PipeIteratorConfiguration) - def from_pipe(cls, pipe: Pipe, *, max_parallelism: int = 100, worker_threads: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": - if pipe.parent: - pipe = pipe.full_pipe() - # head must be iterator - assert isinstance(pipe.head, Iterator) - # create extractor - extract = cls(max_parallelism, worker_threads, futures_poll_interval) - # add as first source - extract._sources.append(SourcePipeItem(pipe.head, 0, pipe)) - return extract - - @classmethod - @with_config(spec=PipeIteratorConfiguration) - def from_pipes(cls, pipes: Sequence[Pipe], yield_parents: bool = True, *, max_parallelism: int = 100, worker_threads: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": - extract = cls(max_parallelism, worker_threads, futures_poll_interval) - # clone all pipes before iterating (recursively) as we will fork them and this add steps - pipes = PipeIterator.clone_pipes(pipes) - - def _fork_pipeline(pipe: Pipe) -> None: - if pipe.parent: - # fork the parent pipe - pipe.parent.fork(pipe) - # make the parent yield by sending a clone of item to itself with position at the end - if yield_parents and pipe.parent in pipes: - # fork is last step of the pipe so it will yield - pipe.parent.fork(pipe.parent, len(pipe.parent) - 1) - _fork_pipeline(pipe.parent) - else: - # head of independent pipe must be iterator - assert isinstance(pipe.head, Iterator) - # add every head as source only once - if not any(i.pipe == pipe for i in extract._sources): - print("add to sources: " + pipe.name) - extract._sources.append(SourcePipeItem(pipe.head, 0, pipe)) - - - for pipe in reversed(pipes): - _fork_pipeline(pipe) - - return extract - - def __next__(self) -> PipeItem: - pipe_item: Union[ResolvablePipeItem, SourcePipeItem] = None - # __next__ should call itself to remove the `while` loop and continue clauses but that may lead to stack overflows: there's no tail recursion opt in python - # https://stackoverflow.com/questions/13591970/does-python-optimize-tail-recursion (see Y combinator on how it could be emulated) - while True: - # do we need new item? - if pipe_item is None: - # process element from the futures - if len(self._futures) > 0: - pipe_item = self._resolve_futures() - # if none then take element from the newest source - if pipe_item is None: - pipe_item = self._get_source_item() - - if pipe_item is None: - if len(self._futures) == 0 and len(self._sources) == 0: - # no more elements in futures or sources - raise StopIteration() - else: - # if len(_sources - # print("waiting") - sleep(self.futures_poll_interval) - continue - - # if item is iterator, then add it as a new source - if isinstance(pipe_item.item, Iterator): - # print(f"adding iterable {item}") - self._sources.append(SourcePipeItem(pipe_item.item, pipe_item.step, pipe_item.pipe)) - pipe_item = None - continue - - if isinstance(pipe_item.item, Awaitable) or callable(pipe_item.item): - # do we have a free slot or one of the slots is done? 
- if len(self._futures) < self.max_parallel_items or self._next_future() >= 0: - if isinstance(pipe_item.item, Awaitable): - future = asyncio.run_coroutine_threadsafe(pipe_item.item, self._ensure_async_pool()) - else: - future = self._ensure_thread_pool().submit(pipe_item.item) - # print(future) - self._futures.append(FuturePipeItem(future, pipe_item.step, pipe_item.pipe)) # type: ignore - # pipe item consumed for now, request a new one - pipe_item = None - continue - else: - # print("maximum futures exceeded, waiting") - sleep(self.futures_poll_interval) - # try same item later - continue - - # if we are at the end of the pipe then yield element - # print(pipe_item) - if pipe_item.step == len(pipe_item.pipe) - 1: - # must be resolved - if isinstance(pipe_item.item, (Iterator, Awaitable)) or callable(pipe_item.pipe): - raise PipeItemProcessingError("Pipe item not processed", pipe_item) - # mypy not able to figure out that item was resolved - return pipe_item # type: ignore - - # advance to next step - step = pipe_item.pipe[pipe_item.step + 1] - assert callable(step) - item = step(pipe_item.item) - pipe_item = ResolvablePipeItem(item, pipe_item.step + 1, pipe_item.pipe) # type: ignore - - - def _ensure_async_pool(self) -> asyncio.AbstractEventLoop: - # lazily create async pool is separate thread - if self._async_pool: - return self._async_pool - - def start_background_loop(loop: asyncio.AbstractEventLoop) -> None: - asyncio.set_event_loop(loop) - loop.run_forever() - - self._async_pool = asyncio.new_event_loop() - self._async_pool_thread = Thread(target=start_background_loop, args=(self._async_pool,), daemon=True) - self._async_pool_thread.start() - - # start or return async pool - return self._async_pool - - def _ensure_thread_pool(self) -> ThreadPoolExecutor: - # lazily start or return thread pool - if self._thread_pool: - return self._thread_pool - - self._thread_pool = ThreadPoolExecutor(self.worker_threads) - return self._thread_pool - - def __enter__(self) -> "PipeIterator": - return self - - def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: types.TracebackType) -> None: - - def stop_background_loop(loop: asyncio.AbstractEventLoop) -> None: - loop.stop() - - for f, _, _ in self._futures: - if not f.done(): - f.cancel() - print("stopping loop") - if self._async_pool: - self._async_pool.call_soon_threadsafe(stop_background_loop, self._async_pool) - print("joining thread") - self._async_pool_thread.join() - self._async_pool = None - self._async_pool_thread = None - if self._thread_pool: - self._thread_pool.shutdown(wait=True) - self._thread_pool = None - - def _next_future(self) -> int: - return next((i for i, val in enumerate(self._futures) if val.item.done()), -1) - - def _resolve_futures(self) -> ResolvablePipeItem: - # no futures at all - if len(self._futures) == 0: - return None - - # anything done? 
- idx = self._next_future() - if idx == -1: - # nothing done - return None - - future, step, pipe = self._futures.pop(idx) - - if future.cancelled(): - # get next future - return self._resolve_futures() - - if future.exception(): - raise future.exception() - - return ResolvablePipeItem(future.result(), step, pipe) - - def _get_source_item(self) -> ResolvablePipeItem: - # no more sources to iterate - if len(self._sources) == 0: - return None - - # get items from last added iterator, this makes the overall Pipe as close to FIFO as possible - gen, step, pipe = self._sources[-1] - try: - item = next(gen) - # full pipe item may be returned, this is used by ForkPipe step - # to redirect execution of an item to another pipe - if isinstance(item, ResolvablePipeItem): - return item - else: - # keep the item assigned step and pipe - return ResolvablePipeItem(item, step, pipe) - except StopIteration: - # remove empty iterator and try another source - self._sources.pop() - return self._get_source_item() - - @staticmethod - def clone_pipes(pipes: Sequence[Pipe]) -> Sequence[Pipe]: - # will clone the pipes including the dependent ones - cloned_pipes = [p.clone() for p in pipes] - cloned_pairs = {id(p): c for p, c in zip(pipes, cloned_pipes)} - - for clone in cloned_pipes: - while True: - if not clone.parent: - break - # if already a clone - if clone.parent in cloned_pairs.values(): - break - # clone if parent pipe not yet cloned - if id(clone.parent) not in cloned_pairs: - print("cloning:" + clone.parent.name) - cloned_pairs[id(clone.parent)] = clone.parent.clone() - # replace with clone - print(f"replace depends on {clone.name} to {clone.parent.name}") - clone.parent = cloned_pairs[id(clone.parent)] - # recurr with clone - clone = clone.parent - - return cloned_pipes - - -class PipeException(DltException): - pass - - -class CreatePipeException(PipeException): - pass - - -class PipeItemProcessingError(PipeException): - pass - diff --git a/experiments/pipeline/pipeline.py b/experiments/pipeline/pipeline.py index 1513929e1b..faa642be06 100644 --- a/experiments/pipeline/pipeline.py +++ b/experiments/pipeline/pipeline.py @@ -31,7 +31,7 @@ from experiments.pipeline.configuration import get_config from experiments.pipeline.exceptions import PipelineConfigMissing, PipelineConfiguredException, MissingDependencyException, PipelineStepFailed -from experiments.pipeline.sources import DltSource, TResolvableDataItem +from dlt.extract.sources import DltSource, TResolvableDataItem TConnectionString = NewType("TConnectionString", str) @@ -53,63 +53,23 @@ class TPipelineState(TypedDict): # sources: Dict[str, TSourceState] -class PipelineConfiguration(RunConfiguration): - WORKING_DIR: Optional[str] = None - PIPELINE_SECRET: Optional[TSecretValue] = None - DROP_EXISTING_DATA: bool = False - - def check_integrity(self) -> None: - if self.PIPELINE_SECRET: - self.PIPELINE_SECRET = uniq_id() - - class Pipeline: - ACTIVE_INSTANCE: "Pipeline" = None STATE_FILE = "state.json" - def __new__(cls: Type["Pipeline"]) -> "Pipeline": - cls.ACTIVE_INSTANCE = super().__new__(cls) - return cls.ACTIVE_INSTANCE - - def __init__(self): - # pipeline is not configured yet - # self.is_configured = False - # self.pipeline_name: str = None - # self.pipeline_secret: str = None - # self.default_schema_name: str = None - # self.default_dataset_name: str = None - # self.working_dir: str = None - # self.is_transient: bool = None - self.CONFIG: Type[PipelineConfiguration] = None + def __init__(self, pipeline_name: str, working_dir: str, 
pipeline_secret: TSecretValue, runtime: RunConfiguration): + self.pipeline_name = pipeline_name + self.working_dir = working_dir + self.pipeline_secret = pipeline_secret + self.runtime_config = runtime self.root_folder: str = None - self._initial_values: DictStrAny = {} + # self._initial_values: DictStrAny = {} self._state: TPipelineState = {} self._pipeline_storage: FileStorage = None self._extractor_storage: ExtractorStorageBase = None self._schema_storage: LiveSchemaStorage = None - def only_not_configured(f: TFun) -> TFun: - - @wraps(f) - def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: - if self.CONFIG: - raise PipelineConfiguredException(f.__name__) - return f(self, *args, **kwargs) - - return _wrap - - def maybe_default_config(f: TFun) -> TFun: - - @wraps(f) - def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: - if not self.CONFIG: - self.configure() - return f(self, *args, **kwargs) - - return _wrap - def with_state_sync(f: TFun) -> TFun: @wraps(f) @@ -130,42 +90,15 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: return _wrap - - @overload - def configure(self, - pipeline_name: str = None, - working_dir: str = None, - pipeline_secret: TSecretValue = None, - drop_existing_data: bool = False, - import_schema_path: str = None, - export_schema_path: str = None, - destination_name: str = None, - log_level: str = "INFO" - ) -> None: - ... - - - @only_not_configured @with_state_sync - def configure(self, **kwargs: Any) -> None: - # keep the locals to be able to initialize configs at any time - self._initial_values.update(**kwargs) - # resolve pipeline configuration - self.CONFIG = self._get_config(PipelineConfiguration) - - # use system temp folder if not specified - if not self.CONFIG.WORKING_DIR: - self.CONFIG.WORKING_DIR = tempfile.gettempdir() - self.root_folder = os.path.join(self.CONFIG.WORKING_DIR, self.CONFIG.pipeline_name) - self._set_common_initial_values() + def _configure(self) -> None: + # compute the folder that keeps all of the pipeline state + FileStorage.validate_file_name_component(self.pipeline_name) + self.root_folder = os.path.join(self.working_dir, self.pipeline_name) # create pipeline working dir self._pipeline_storage = FileStorage(self.root_folder, makedirs=False) - # remove existing pipeline if requested - if self._pipeline_storage.has_folder(".") and self.CONFIG.drop_existing_data: - self._pipeline_storage.delete_folder(".") - # restore pipeline if folder exists and contains state if self._pipeline_storage.has_file(Pipeline.STATE_FILE): self._restore_pipeline() @@ -173,7 +106,7 @@ def configure(self, **kwargs: Any) -> None: self._create_pipeline() # create schema storage - self._schema_storage = LiveSchemaStorage(self._get_config(SchemaVolumeConfiguration), makedirs=True) + self._schema_storage = LiveSchemaStorage(makedirs=True) # create extractor storage self._extractor_storage = ExtractorStorageBase( "1.0.0", @@ -185,6 +118,11 @@ def configure(self, **kwargs: Any) -> None: initialize_runner(self.CONFIG) + def drop() -> "Pipeline": + """Deletes existing pipeline state, schemas and drops datasets at the destination if present""" + pass + + def _get_config(self, spec: Type[TAny], accept_partial: bool = False) -> Type[TAny]: print(self._initial_values) return make_configuration(spec, spec, initial_values=self._initial_values, accept_partial=accept_partial) @@ -270,7 +208,7 @@ def normalize(self, dry_run: bool = False, workers: int = 1, max_events_in_chunk "POOL_TYPE": "thread" if workers == 1 else "process" }) try: - ec 
= runner.run_pool(normalize.CONFIG, normalize) + ec = runner.run_pool(normalize.config, normalize) # in any other case we raise if runner exited with status failed if runner.LAST_RUN_METRICS.has_failed: raise PipelineStepFailed("normalize", self.last_run_exception, runner.LAST_RUN_METRICS) @@ -302,7 +240,7 @@ def load( # then load print(locals()) load = self._configure_load(locals(), credentials) - runner.run_pool(load.CONFIG, load) + runner.run_pool(load.config, load) if runner.LAST_RUN_METRICS.has_failed: raise PipelineStepFailed("load", self.last_run_exception, runner.LAST_RUN_METRICS) diff --git a/experiments/pipeline/sources.py b/experiments/pipeline/sources.py deleted file mode 100644 index 4b85646909..0000000000 --- a/experiments/pipeline/sources.py +++ /dev/null @@ -1,219 +0,0 @@ -import contextlib -from copy import deepcopy -import inspect -from typing import AsyncIterable, AsyncIterator, Coroutine, Dict, Generator, Iterable, Iterator, List, Set, TypedDict, Union, Awaitable, Callable, Sequence, TypeVar, cast, Optional, Any -from dlt.common.exceptions import DltException -from dlt.common.schema.utils import new_table - -from dlt.common.typing import TDataItem -from dlt.common.sources import TFunDataItemDynHint, TDirectDataItem -from dlt.common.schema.schema import Schema -from dlt.common.schema.typing import TPartialTableSchema, TTableSchema, TTableSchemaColumns, TWriteDisposition - -from experiments.pipeline.pipe import FilterItem, Pipe, CreatePipeException, PipeIterator - - -class TTableSchemaTemplate(TypedDict, total=False): - name: Union[str, TFunDataItemDynHint] - description: Union[str, TFunDataItemDynHint] - write_disposition: Union[TWriteDisposition, TFunDataItemDynHint] - # table_sealed: Optional[bool] - parent: Union[str, TFunDataItemDynHint] - columns: Union[TTableSchemaColumns, TFunDataItemDynHint] - - -class DltResourceSchema: - def __init__(self, name: str, table_schema_template: TTableSchemaTemplate = None): - # self.__name__ = name - self.name = name - self._table_name_hint_fun: TFunDataItemDynHint = None - self._table_has_other_dynamic_hints: bool = False - self._table_schema_template: TTableSchemaTemplate = None - self._table_schema: TPartialTableSchema = None - if table_schema_template: - self._set_template(table_schema_template) - - def table_schema(self, item: TDataItem = None) -> TPartialTableSchema: - - if not self._table_schema_template: - # if table template is not present, generate partial table from name - if not self._table_schema: - self._table_schema = new_table(self.name) - return self._table_schema - - def _resolve_hint(hint: Union[Any, TFunDataItemDynHint]) -> Any: - if callable(hint): - return hint(item) - else: - return hint - - # if table template present and has dynamic hints, the data item must be provided - if self._table_name_hint_fun: - if item is None: - raise DataItemRequiredForDynamicTableHints(self.name) - else: - cloned_template = deepcopy(self._table_schema_template) - return cast(TPartialTableSchema, {k: _resolve_hint(v) for k, v in cloned_template.items()}) - else: - return cast(TPartialTableSchema, self._table_schema_template) - - def _set_template(self, table_schema_template: TTableSchemaTemplate) -> None: - # if "name" is callable in the template then the table schema requires actual data item to be inferred - name_hint = table_schema_template.get("name") - if callable(name_hint): - self._table_name_hint_fun = name_hint - # check if any other hints in the table template should be inferred from data - 
self._table_has_other_dynamic_hints = any(callable(v) for k, v in table_schema_template.items() if k != "name") - - if self._table_has_other_dynamic_hints and not self._table_name_hint_fun: - raise InvalidTableSchemaTemplate("Table name must be a function if any other table hint is a function") - self._table_schema_template = table_schema_template - - -class DltResource(Iterable[TDirectDataItem], DltResourceSchema): - def __init__(self, pipe: Pipe, table_schema_template: TTableSchemaTemplate): - self.name = pipe.name - self._pipe = pipe - super().__init__(self.name, table_schema_template) - - @classmethod - def from_data(cls, data: Any, name: str = None, table_schema_template: TTableSchemaTemplate = None) -> "DltResource": - # call functions assuming that they do not take any parameters, typically they are generator functions - if callable(data): - data = data() - - if isinstance(data, DltResource): - return data - - if isinstance(data, Pipe): - return cls(data, table_schema_template) - - # several iterable types are not allowed and must be excluded right away - if isinstance(data, (AsyncIterator, AsyncIterable, str, dict)): - raise InvalidResourceDataType("Invalid data type for DltResource", type(data)) - - # create resource from iterator or iterable - if isinstance(data, (Iterable, Iterator)): - if inspect.isgenerator(data): - name = name or data.__name__ - else: - name = name or None - if not name: - raise ResourceNameRequired("The DltResource name was not provide or could not be inferred.") - pipe = Pipe.from_iterable(name, data) - return cls(pipe, table_schema_template) - - # some other data type that is not supported - raise InvalidResourceDataType("Invalid data type for DltResource", type(data)) - - - def select(self, *table_names: Iterable[str]) -> "DltResource": - if not self._table_name_hint_fun: - raise CreatePipeException("Table name is not dynamic, table selection impossible") - - def _filter(item: TDataItem) -> bool: - return self._table_name_hint_fun(item) in table_names - - # add filtering function at the end of pipe - self._pipe.add_step(FilterItem(_filter)) - return self - - def map(self) -> None: - raise NotImplementedError() - - def flat_map(self) -> None: - raise NotImplementedError() - - def filter(self) -> None: - raise NotImplementedError() - - def __iter__(self) -> Iterator[TDirectDataItem]: - return map(lambda item: item.item, PipeIterator.from_pipe(self._pipe)) - - def __repr__(self) -> str: - return f"DltResource {self.name} ({self._pipe._pipe_id}) at {id(self)}" - - -class DltSource(Iterable[TDirectDataItem]): - def __init__(self, schema: Schema, resources: Sequence[DltResource] = None) -> None: - self.name = schema.name - self._schema = schema - self._resources: List[DltResource] = list(resources or []) - self._enabled_resource_names: Set[str] = set(r.name for r in self._resources) - - @classmethod - def from_data(cls, schema: Schema, data: Any) -> "DltSource": - # creates source from various forms of data - if isinstance(data, DltSource): - return data - - # several iterable types are not allowed and must be excluded right away - if isinstance(data, (AsyncIterator, AsyncIterable, str, dict)): - raise InvalidSourceDataType("Invalid data type for DltSource", type(data)) - - # in case of sequence, enumerate items and convert them into resources - if isinstance(data, Sequence): - resources = [DltResource.from_data(i) for i in data] - else: - resources = [DltResource.from_data(data)] - - return cls(schema, resources) - - - def __getitem__(self, name: str) -> 
List[DltResource]: - if name not in self._enabled_resource_names: - raise KeyError(name) - return [r for r in self._resources if r.name == name] - - def resource_by_pipe(self, pipe: Pipe) -> DltResource: - # identify pipes by memory pointer - return next(r for r in self._resources if r._pipe._pipe_id is pipe._pipe_id) - - @property - def resources(self) -> Sequence[DltResource]: - return [r for r in self._resources if r.name in self._enabled_resource_names] - - @property - def pipes(self) -> Sequence[Pipe]: - return [r._pipe for r in self._resources if r.name in self._enabled_resource_names] - - @property - def schema(self) -> Schema: - return self._schema - - def discover_schema(self) -> Schema: - # extract tables from all resources and update internal schema - for r in self._resources: - # names must be normalized here - with contextlib.suppress(DataItemRequiredForDynamicTableHints): - partial_table = self._schema.normalize_table_identifiers(r.table_schema()) - self._schema.update_schema(partial_table) - return self._schema - - def select(self, *resource_names: str) -> "DltSource": - # make sure all selected resources exist - for name in resource_names: - self.__getitem__(name) - self._enabled_resource_names = set(resource_names) - return self - - - def __iter__(self) -> Iterator[TDirectDataItem]: - return map(lambda item: item.item, PipeIterator.from_pipes(self.pipes)) - - def __repr__(self) -> str: - return f"DltSource {self.name} at {id(self)}" - - -class DltSourceException(DltException): - pass - - -class DataItemRequiredForDynamicTableHints(DltException): - def __init__(self, resource_name: str) -> None: - self.resource_name = resource_name - super().__init__(f"Instance of Data Item required to generate table schema in resource {resource_name}") - - - -# class diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index 1fef49012a..7b53438b11 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -14,9 +14,6 @@ from tests.utils import preserve_environ, add_config_dict_to_env from tests.common.configuration.utils import MockProvider, WithCredentialsConfiguration, WrongConfiguration, SecretConfiguration, NamespacedConfiguration, environment, mock_provider -# used to test version -__version__ = "1.0.5" - COERCIONS = { 'str_val': 'test string', 'int_val': 12345, @@ -298,18 +295,13 @@ class _SecretCredentials(RunConfiguration): for key in C: assert C[key] == expected_dict[key] # version is present as attr but not present in dict - assert hasattr(C, "_version") assert hasattr(C, "__is_resolved__") assert hasattr(C, "__namespace__") - with pytest.raises(KeyError): - C["_version"] - # set ops # update supported and non existing attributes are ignored C.update({"pipeline_name": "old pipe", "__version": "1.1.1"}) assert C.pipeline_name == "old pipe" == C["pipeline_name"] - assert C._version != "1.1.1" # delete is not supported with pytest.raises(KeyError): @@ -322,9 +314,6 @@ class _SecretCredentials(RunConfiguration): C["pipeline_name"] = "new pipe" assert C.pipeline_name == "new pipe" == C["pipeline_name"] - with pytest.raises(KeyError): - C["_version"] = "1.1.1" - def test_fields_with_no_default_to_null(environment: Any) -> None: # fields with no default are promoted to class attrs with none @@ -531,20 +520,6 @@ def test_accept_partial(environment: Any) -> None: assert C.is_partial() -def test_finds_version(environment: Any) -> None: - global __version__ - - v = 
__version__ - C = resolve.make_configuration(BaseConfiguration()) - assert C._version == v - try: - del globals()["__version__"] - C = resolve.make_configuration(BaseConfiguration()) - assert not hasattr(C, "_version") - finally: - __version__ = v - - def test_coercion_rules() -> None: with pytest.raises(ConfigEnvValueCannotBeCoercedException): coerce_single_value("key", "some string", int) diff --git a/tests/common/configuration/test_inject.py b/tests/common/configuration/test_inject.py index 7e9647c202..01876e1b9d 100644 --- a/tests/common/configuration/test_inject.py +++ b/tests/common/configuration/test_inject.py @@ -132,6 +132,30 @@ def test_set_defaults_for_positional_args() -> None: pass +def test_inject_spec_remainder_in_kwargs() -> None: + # if the wrapped func contains kwargs then all the fields from spec without matching func args must be injected in kwargs + pass + + +def test_inject_spec_in_kwargs() -> None: + # the resolved spec is injected in kwargs + pass + + +def test_resolved_spec_in_kwargs_pass_through() -> None: + # if last_config is in kwargs then use it and do not resolve it anew + pass + + +def test_inject_spec_into_argument_with_spec_type() -> None: + # if signature contains argument with type of SPEC, it gets injected there + pass + + +def test_initial_spec_from_arg_with_spec_type() -> None: + # if signature contains argument with type of SPEC, get its value to init SPEC (instead of calling the constructor()) + pass + def test_auto_derived_spec_type_name() -> None: diff --git a/tests/common/storages/test_normalize_storage.py b/tests/common/storages/test_normalize_storage.py index 18c5d1e601..8deb472140 100644 --- a/tests/common/storages/test_normalize_storage.py +++ b/tests/common/storages/test_normalize_storage.py @@ -28,19 +28,19 @@ def test_build_extracted_file_name() -> None: def test_full_migration_path() -> None: # create directory structure - s = NormalizeStorage(True, NormalizeVolumeConfiguration) + s = NormalizeStorage(True) # overwrite known initial version write_version(s.storage, "1.0.0") # must be able to migrate to current version - s = NormalizeStorage(True, NormalizeVolumeConfiguration) + s = NormalizeStorage(True) assert s.version == NormalizeStorage.STORAGE_VERSION def test_unknown_migration_path() -> None: # create directory structure - s = NormalizeStorage(True, NormalizeVolumeConfiguration) + s = NormalizeStorage(True) # overwrite known initial version write_version(s.storage, "10.0.0") # must be able to migrate to current version with pytest.raises(NoMigrationPathException): - NormalizeStorage(False, NormalizeVolumeConfiguration) + NormalizeStorage(False) diff --git a/tests/common/storages/test_schema_storage.py b/tests/common/storages/test_schema_storage.py index 31ead5021d..95cb9406f8 100644 --- a/tests/common/storages/test_schema_storage.py +++ b/tests/common/storages/test_schema_storage.py @@ -19,25 +19,25 @@ @pytest.fixture def storage() -> SchemaStorage: - return init_storage() + return init_storage(SchemaVolumeConfiguration()) @pytest.fixture def synced_storage() -> SchemaStorage: # will be created in /schemas - return init_storage({"import_schema_path": TEST_STORAGE_ROOT + "/import", "export_schema_path": TEST_STORAGE_ROOT + "/import"}) + return init_storage(SchemaVolumeConfiguration(import_schema_path=TEST_STORAGE_ROOT + "/import", export_schema_path=TEST_STORAGE_ROOT + "/import")) @pytest.fixture def ie_storage() -> SchemaStorage: # will be created in /schemas - return init_storage({"import_schema_path": TEST_STORAGE_ROOT + 
"/import", "export_schema_path": TEST_STORAGE_ROOT + "/export"}) + return init_storage(SchemaVolumeConfiguration(import_schema_path=TEST_STORAGE_ROOT + "/import", export_schema_path=TEST_STORAGE_ROOT + "/export")) -def init_storage(initial: DictStrAny = None) -> SchemaStorage: - C = make_configuration(SchemaVolumeConfiguration(), initial_value=initial) +def init_storage(C: SchemaVolumeConfiguration) -> SchemaStorage: # use live schema storage for test which must be backward compatible with schema storage s = LiveSchemaStorage(C, makedirs=True) + assert C is s.C if C.export_schema_path: os.makedirs(C.export_schema_path, exist_ok=True) if C.import_schema_path: diff --git a/tests/conftest.py b/tests/conftest.py index 78498e0b64..e6a875c644 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,12 +14,14 @@ def pytest_configure(config): load_volume_configuration.LoadVolumeConfiguration.load_volume_path = os.path.join(test_storage_root, "load") normalize_volume_configuration.NormalizeVolumeConfiguration.normalize_volume_path = os.path.join(test_storage_root, "normalize") - if hasattr(normalize_volume_configuration.NormalizeVolumeConfiguration, "__init__"): - # delete __init__, otherwise it will not be recreated by dataclass - delattr(normalize_volume_configuration.NormalizeVolumeConfiguration, "__init__") - normalize_volume_configuration.NormalizeVolumeConfiguration = dataclasses.dataclass(normalize_volume_configuration.NormalizeVolumeConfiguration, init=True, repr=False) + # delete __init__, otherwise it will not be recreated by dataclass + delattr(normalize_volume_configuration.NormalizeVolumeConfiguration, "__init__") + normalize_volume_configuration.NormalizeVolumeConfiguration = dataclasses.dataclass(normalize_volume_configuration.NormalizeVolumeConfiguration, init=True, repr=False) schema_volume_configuration.SchemaVolumeConfiguration.schema_volume_path = os.path.join(test_storage_root, "schemas") + delattr(schema_volume_configuration.SchemaVolumeConfiguration, "__init__") + schema_volume_configuration.SchemaVolumeConfiguration = dataclasses.dataclass(schema_volume_configuration.SchemaVolumeConfiguration, init=True, repr=False) + assert run_configuration.RunConfiguration.config_files_storage_path == os.path.join(test_storage_root, "config/%s") assert run_configuration.RunConfiguration().config_files_storage_path == os.path.join(test_storage_root, "config/%s") diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index 7a7979c5bf..618c7a2e1b 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -17,7 +17,7 @@ from dlt.load.configuration import configuration, LoaderConfiguration from dlt.load.dummy import client -from dlt.load import Load, __version__ +from dlt.load import Load from dlt.load.dummy.configuration import DummyClientConfiguration from tests.utils import clean_test_storage, init_logger @@ -41,10 +41,10 @@ def logger_autouse() -> None: def test_gen_configuration() -> None: load = setup_loader() - assert LoaderConfiguration in type(load.CONFIG).mro() + assert LoaderConfiguration in type(load.config).mro() # mock missing config values load = setup_loader(initial_values={"load_volume_path": LoaderConfiguration.load_volume_path}) - assert LoaderConfiguration in type(load.CONFIG).mro() + assert LoaderConfiguration in type(load.config).mro() def test_spool_job_started() -> None: @@ -221,7 +221,7 @@ def test_failed_loop() -> None: def test_completed_loop_with_delete_completed() -> None: load = 
setup_loader(initial_client_values={"completed_prob": 1.0}) - load.CONFIG.delete_completed_jobs = True + load.config.delete_completed_jobs = True load.load_storage = load.create_storage(is_storage_owner=False) assert_complete_job(load, load.load_storage.storage, should_delete_completed=True) @@ -278,10 +278,6 @@ def test_exceptions() -> None: raise AssertionError() -def test_version() -> None: - assert configuration({"client_type": "dummy"})._version == __version__ - - def assert_complete_job(load: Load, storage: FileStorage, should_delete_completed: bool = False) -> None: load_id, _ = prepare_load_package( load.load_storage, diff --git a/tests/utils.py b/tests/utils.py index adf954899c..56c76a642d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,9 +6,11 @@ import logging from os import environ -from dlt.common.configuration.providers import EnvironProvider +from dlt.common.configuration.container import Container +from dlt.common.configuration.providers import EnvironProvider, DictionaryProvider from dlt.common.configuration.resolve import make_configuration, serialize_value from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration +from dlt.common.configuration.specs.config_providers_configuration import ConfigProvidersListConfiguration from dlt.common.logger import init_logging_from_config from dlt.common.storages import FileStorage from dlt.common.schema import Schema @@ -18,6 +20,12 @@ TEST_STORAGE_ROOT = "_storage" +# add test dictionary provider +TEST_DICT_CONFIG_PROVIDER = DictionaryProvider() +providers_config = Container()[ConfigProvidersListConfiguration] +providers_config.providers.append(TEST_DICT_CONFIG_PROVIDER) + + class MockHttpResponse(): def __init__(self, status_code: int) -> None: @@ -69,12 +77,10 @@ def clean_test_storage(init_normalize: bool = False, init_loader: bool = False) storage.create_folder(".") if init_normalize: from dlt.common.storages import NormalizeStorage - from dlt.common.configuration.specs import NormalizeVolumeConfiguration - NormalizeStorage(True, NormalizeVolumeConfiguration) + NormalizeStorage(True) if init_loader: from dlt.common.storages import LoadStorage - from dlt.common.configuration.specs import LoadVolumeConfiguration - LoadStorage(True, LoadVolumeConfiguration, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + LoadStorage(True, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) return storage @@ -106,4 +112,5 @@ def assert_no_dict_key_starts_with(d: StrAny, key_prefix: str) -> None: skipifpypy = pytest.mark.skipif( platform.python_implementation() == "PyPy", reason="won't run in PyPy interpreter" -) \ No newline at end of file +) + From c4762545dc6594cf80e727fba79a0946a341e68b Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sun, 16 Oct 2022 12:36:57 +0200 Subject: [PATCH 36/66] adds create pipeline examples --- experiments/pipeline/create_pipeline.md | 119 ++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 experiments/pipeline/create_pipeline.md diff --git a/experiments/pipeline/create_pipeline.md b/experiments/pipeline/create_pipeline.md new file mode 100644 index 0000000000..058331271c --- /dev/null +++ b/experiments/pipeline/create_pipeline.md @@ -0,0 +1,119 @@ +## Mockup code for generic template credentials + +This is a toml file, for BigQuery credentials destination and instruction how to add source credentials. 
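These entries are what `dlt.secrets` resolves at run time; a minimal sketch of that lookup, assuming the accessor shown in the pipeline examples further below:

```python
import dlt

# the source API key sits at the top level of the secrets file
api_key = dlt.secrets["api_key"]

# the destination credentials live under [gcp_credentials]; reading the whole
# section as a mapping is an assumption of this sketch
gcp_credentials = dlt.secrets["gcp_credentials"]
```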
+
+I assume that the new pipeline is accessing a REST API
+
+```toml
+# provide credentials to `taktile` source below, for example
+# api_key = "api key to access taktile endpoint"
+
+[gcp_credentials]
+client_email =
+private_key =
+project_id =
+```
+
+## Mockup code for taktile credentials
+
+```toml
+taktile_api_key="96e6m3/OFSumLRG9mnIr"
+
+[gcp_credentials]
+client_email =
+private_key =
+project_id =
+```
+
+
+## Mockup code pipeline script template with nice UX
+
+This is a template made for the BigQuery destination and the source named `taktile`. This already proposes a nice structure for the code so the pipeline may be developed further.
+
+
+```python
+import requests
+import dlt
+
+# The code below is an example of a well structured pipeline
+# @Ty if you want I can write more comments and explanations
+
+@dlt.source
+def taktile_data():
+    # retrieve credentials via DLT secrets
+    api_key = dlt.secrets["api_key"]
+
+    # make a call to the endpoint with the requests library
+    resp = requests.get("https://example.com/data", headers={"Authorization": api_key})
+    resp.raise_for_status()
+    data = resp.json()
+
+    # you may process the data here
+
+    # return resource to be loaded into `data` table
+    return dlt.resource(data, name="data")
+
+dlt.run(taktile_data(), destination="bigquery")
+```
+
+
+## Mockup code of taktile pipeline script with nice UX
+
+Example of the simplest ad hoc pipeline, without any structure:
+
+```python
+import requests
+import dlt
+
+resp = requests.get(
+    "https://taktile.com/api/v2/logs",
+    headers={"Authorization": dlt.secrets["taktile_api_key"]})
+resp.raise_for_status()
+data = resp.json()
+
+dlt.run(data["result"], name="logs", destination="bigquery")
+```
+
+Example for an endpoint returning only one resource:
+
+```python
+import requests
+import dlt
+
+@dlt.source
+def taktile_data():
+    resp = requests.get(
+        "https://taktile.com/api/v2/logs",
+        headers={"Authorization": dlt.secrets["taktile_api_key"]})
+    resp.raise_for_status()
+    data = resp.json()
+
+    return dlt.resource(data["result"], name="logs")
+
+dlt.run(taktile_data(), destination="bigquery")
+```
+
+With two resources:
+
+```python
+import requests
+import dlt
+
+@dlt.source
+def taktile_data():
+    resp = requests.get(
+        "https://taktile.com/api/v2/logs",
+        headers={"Authorization": dlt.secrets["taktile_api_key"]})
+    resp.raise_for_status()
+    logs = resp.json()["results"]
+
+    resp = requests.get(
+        "https://taktile.com/api/v2/decisions",
+        headers={"Authorization": dlt.secrets["taktile_api_key"]})
+    resp.raise_for_status()
+    decisions = resp.json()["results"]
+
+    return dlt.resource(logs, name="logs"), dlt.resource(decisions, name="decisions")
+
+dlt.run(taktile_data(), destination="bigquery")
+```
From ec2059137c294a817964dce6186d50257c10fe90 Mon Sep 17 00:00:00 2001
From: Marcin Rudolf
Date: Tue, 18 Oct 2022 13:29:37 +0200
Subject: [PATCH 37/66] improves config resolution: last exception is preserved, resolved configs are not resolved again, external and injected namespaces are compacted, injected configs renamed to contexts

---
 dlt/common/configuration/__init__.py | 2 +-
 dlt/common/configuration/container.py | 51 ++++---
 dlt/common/configuration/exceptions.py | 10 +-
 dlt/common/configuration/inject.py | 25 ++-
 .../configuration/providers/container.py | 25 +--
 .../configuration/providers/dictionary.py | 6 +-
 dlt/common/configuration/providers/toml.py | 2 +-
 dlt/common/configuration/resolve.py | 144 ++++++++++++------
 dlt/common/configuration/specs/__init__.py | 6 +-
.../configuration/specs/base_configuration.py | 23 ++- .../specs/config_namespace_context.py | 14 ++ .../specs/config_providers_configuration.py | 25 --- .../specs/config_providers_context.py | 43 ++++++ .../specs/destination_capabilities_context.py | 25 +++ .../specs/load_volume_configuration.py | 2 +- .../specs/normalize_volume_configuration.py | 6 + .../specs/pool_runner_configuration.py | 5 +- .../configuration/specs/run_configuration.py | 2 +- .../specs/schema_volume_configuration.py | 6 +- .../configuration/test_configuration.py | 122 +++++++++++---- tests/common/configuration/test_container.py | 119 +++++++++------ .../configuration/test_environ_provider.py | 18 +-- tests/common/configuration/test_inject.py | 20 +++ tests/common/configuration/test_namespaces.py | 117 +++++++++----- tests/common/configuration/utils.py | 4 +- 25 files changed, 565 insertions(+), 257 deletions(-) create mode 100644 dlt/common/configuration/specs/config_namespace_context.py delete mode 100644 dlt/common/configuration/specs/config_providers_configuration.py create mode 100644 dlt/common/configuration/specs/config_providers_context.py create mode 100644 dlt/common/configuration/specs/destination_capabilities_context.py diff --git a/dlt/common/configuration/__init__.py b/dlt/common/configuration/__init__.py index 65939a43c6..34e590ac49 100644 --- a/dlt/common/configuration/__init__.py +++ b/dlt/common/configuration/__init__.py @@ -1,5 +1,5 @@ from .specs.base_configuration import configspec, is_valid_hint # noqa: F401 -from .resolve import make_configuration # noqa: F401 +from .resolve import resolve_configuration, inject_namespace # noqa: F401 from .inject import with_config, last_config from .exceptions import ( # noqa: F401 diff --git a/dlt/common/configuration/container.py b/dlt/common/configuration/container.py index f28fa6e9fa..1f0d180c45 100644 --- a/dlt/common/configuration/container.py +++ b/dlt/common/configuration/container.py @@ -1,29 +1,22 @@ from contextlib import contextmanager from typing import Dict, Iterator, Type, TypeVar -from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec -from dlt.common.configuration.exceptions import ContainerInjectableConfigurationMangled +from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext +from dlt.common.configuration.exceptions import ContainerInjectableContextMangled, ContextDefaultCannotBeCreated - -@configspec -class ContainerInjectableConfiguration(BaseConfiguration): - """Base class for all configurations that may be injected from Container.""" - pass - - -TConfiguration = TypeVar("TConfiguration", bound=ContainerInjectableConfiguration) +TConfiguration = TypeVar("TConfiguration", bound=ContainerInjectableContext) class Container: _INSTANCE: "Container" = None - configurations: Dict[Type[ContainerInjectableConfiguration], ContainerInjectableConfiguration] + contexts: Dict[Type[ContainerInjectableContext], ContainerInjectableContext] def __new__(cls: Type["Container"]) -> "Container": if not cls._INSTANCE: cls._INSTANCE = super().__new__(cls) - cls._INSTANCE.configurations = {} + cls._INSTANCE.contexts = {} return cls._INSTANCE def __init__(self) -> None: @@ -31,32 +24,40 @@ def __init__(self) -> None: def __getitem__(self, spec: Type[TConfiguration]) -> TConfiguration: # return existing config object or create it from spec - if not issubclass(spec, ContainerInjectableConfiguration): - raise KeyError(f"{spec.__name__} is not injectable") + if not issubclass(spec, 
ContainerInjectableContext): + raise KeyError(f"{spec.__name__} is not a context") + + item = self.contexts.get(spec) + if item is None: + if spec.can_create_default: + item = spec() + self.contexts[spec] = item + else: + raise ContextDefaultCannotBeCreated(spec) - return self.configurations.setdefault(spec, spec()) # type: ignore + return item # type: ignore def __contains__(self, spec: Type[TConfiguration]) -> bool: - return spec in self.configurations + return spec in self.contexts @contextmanager - def injectable_configuration(self, config: TConfiguration) -> Iterator[TConfiguration]: + def injectable_context(self, config: TConfiguration) -> Iterator[TConfiguration]: spec = type(config) - previous_config: ContainerInjectableConfiguration = None - if spec in self.configurations: - previous_config = self.configurations[spec] + previous_config: ContainerInjectableContext = None + if spec in self.contexts: + previous_config = self.contexts[spec] # set new config and yield context try: - self.configurations[spec] = config + self.contexts[spec] = config yield config finally: # before setting the previous config for given spec, check if there was no overlapping modification - if self.configurations[spec] is config: + if self.contexts[spec] is config: # config is injected for spec so restore previous if previous_config is None: - del self.configurations[spec] + del self.contexts[spec] else: - self.configurations[spec] = previous_config + self.contexts[spec] = previous_config else: # value was modified in the meantime and not restored - raise ContainerInjectableConfigurationMangled(spec, self.configurations[spec], config) + raise ContainerInjectableContextMangled(spec, self.contexts[spec], config) diff --git a/dlt/common/configuration/exceptions.py b/dlt/common/configuration/exceptions.py index f64ce00c39..b0c5967440 100644 --- a/dlt/common/configuration/exceptions.py +++ b/dlt/common/configuration/exceptions.py @@ -94,9 +94,15 @@ def __init__(self, spec: Type[Any], initial_value_type: Type[Any]) -> None: super().__init__(f"Initial value of type {initial_value_type} is not valid for {spec.__name__}") -class ContainerInjectableConfigurationMangled(ConfigurationException): +class ContainerInjectableContextMangled(ConfigurationException): def __init__(self, spec: Type[Any], existing_config: Any, expected_config: Any) -> None: self.spec = spec self.existing_config = existing_config self.expected_config = expected_config - super().__init__(f"When restoring injectable config {spec.__name__}, instance {expected_config} was expected, instead instance {existing_config} was found.") + super().__init__(f"When restoring context {spec.__name__}, instance {expected_config} was expected, instead instance {existing_config} was found.") + + +class ContextDefaultCannotBeCreated(ConfigurationException): + def __init__(self, spec: Type[Any]) -> None: + self.spec = spec + super().__init__(f"Container cannot create the default value of context {spec.__name__}.") diff --git a/dlt/common/configuration/inject.py b/dlt/common/configuration/inject.py index dffeba2257..e7c0ed20c8 100644 --- a/dlt/common/configuration/inject.py +++ b/dlt/common/configuration/inject.py @@ -6,8 +6,9 @@ from inspect import Signature, Parameter from dlt.common.typing import StrAny, TFun, AnyFun -from dlt.common.configuration.resolve import make_configuration +from dlt.common.configuration.resolve import resolve_configuration, inject_namespace from dlt.common.configuration.specs.base_configuration import BaseConfiguration, is_valid_hint, 
configspec +from dlt.common.configuration.specs.config_namespace_context import ConfigNamespacesContext # [^.^_]+ splits by . or _ _SLEEPING_CAT_SPLIT = re.compile("[^.^_]+") @@ -16,16 +17,16 @@ @overload -def with_config(func: TFun, /, spec: Type[BaseConfiguration] = None, only_kw: bool = False, namespaces: Tuple[str, ...] = ()) -> TFun: +def with_config(func: TFun, /, spec: Type[BaseConfiguration] = None, auto_namespace: bool = False, only_kw: bool = False, namespaces: Tuple[str, ...] = ()) -> TFun: ... @overload -def with_config(func: None = ..., /, spec: Type[BaseConfiguration] = None, only_kw: bool = False, namespaces: Tuple[str, ...] = ()) -> Callable[[TFun], TFun]: +def with_config(func: None = ..., /, spec: Type[BaseConfiguration] = None, auto_namespace: bool = False, only_kw: bool = False, namespaces: Tuple[str, ...] = ()) -> Callable[[TFun], TFun]: ... -def with_config(func: Optional[AnyFun] = None, /, spec: Type[BaseConfiguration] = None, only_kw: bool = False, namespaces: Tuple[str, ...] = ()) -> Callable[[TFun], TFun]: +def with_config(func: Optional[AnyFun] = None, /, spec: Type[BaseConfiguration] = None, auto_namespace: bool = False, only_kw: bool = False, namespaces: Tuple[str, ...] = ()) -> Callable[[TFun], TFun]: namespace_f: Callable[[StrAny], str] = None # namespace may be a function from function arguments to namespace @@ -37,6 +38,8 @@ def decorator(f: TFun) -> TFun: sig: Signature = inspect.signature(f) kwargs_arg = next((p for p in sig.parameters.values() if p.kind == Parameter.VAR_KEYWORD), None) spec_arg: Parameter = None + pipeline_name_arg: Parameter = None + namespace_context = ConfigNamespacesContext() if spec is None: SPEC = _spec_from_signature(_get_spec_name_from_f(f), inspect.getmodule(f), sig, only_kw) @@ -50,6 +53,10 @@ def decorator(f: TFun) -> TFun: if p.annotation is SPEC: # if any argument has type SPEC then us it to take initial value spec_arg = p + if p.name == "pipeline_name" and auto_namespace: + # if argument has name pipeline_name and auto_namespace is used, use it to generate namespace context + pipeline_name_arg = p + @wraps(f, new_sig=sig) def _wrap(*args: Any, **kwargs: Any) -> Any: @@ -68,10 +75,14 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: # namespaces may be a string if isinstance(namespaces, str): namespaces = (namespaces,) - # resolve SPEC + # if one of arguments is spec the use it as initial value if spec_arg: config = bound_args.arguments.get(spec_arg.name, None) - config = make_configuration(config or SPEC(), namespaces=namespaces, initial_value=bound_args.arguments) + # resolve SPEC, also provide namespace_context with pipeline_name + if pipeline_name_arg: + namespace_context.pipeline_name = bound_args.arguments.get(pipeline_name_arg.name, None) + with inject_namespace(namespace_context): + config = resolve_configuration(config or SPEC(), namespaces=namespaces, initial_value=bound_args.arguments) resolved_params = dict(config) # overwrite or add resolved params for p in sig.parameters.values(): @@ -83,7 +94,7 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: if kwargs_arg is not None: bound_args.arguments[kwargs_arg.name].update(resolved_params) bound_args.arguments[kwargs_arg.name][_LAST_DLT_CONFIG] = config - # call the function with injected config + # call the function with resolved config return f(*bound_args.args, **bound_args.kwargs) return _wrap # type: ignore diff --git a/dlt/common/configuration/providers/container.py b/dlt/common/configuration/providers/container.py index 1fc40559b1..30699a40e5 100644 --- 
a/dlt/common/configuration/providers/container.py +++ b/dlt/common/configuration/providers/container.py @@ -1,26 +1,33 @@ +import contextlib from typing import Any, Optional, Type, Tuple from dlt.common.configuration.container import Container +from dlt.common.configuration.specs import ContainerInjectableContext from .provider import Provider -class ContainerProvider(Provider): +class ContextProvider(Provider): - NAME = "Injectable Configuration" + NAME = "Injectable Context" + + def __init__(self) -> None: + self.container = Container() @property def name(self) -> str: - return ContainerProvider.NAME + return ContextProvider.NAME def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Optional[Any], str]: assert namespaces == () - # get container singleton - container = Container() - if hint in container: - return Container()[hint], hint.__name__ - else: - return None, str(hint) + + # only context is a valid hint + with contextlib.suppress(TypeError): + if issubclass(hint, ContainerInjectableContext): + # contexts without defaults will raise ContextDefaultCannotBeCreated + return self.container[hint], hint.__name__ + + return None, str(hint) @property def supports_secrets(self) -> bool: diff --git a/dlt/common/configuration/providers/dictionary.py b/dlt/common/configuration/providers/dictionary.py index 906c8d6748..252a2fb216 100644 --- a/dlt/common/configuration/providers/dictionary.py +++ b/dlt/common/configuration/providers/dictionary.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import Any, Iterator, Optional, Type, Tuple +from typing import Any, ClassVar, Iterator, Optional, Type, Tuple from dlt.common.typing import StrAny @@ -8,13 +8,15 @@ class DictionaryProvider(Provider): + NAME: ClassVar[str] = "Dictionary Provider" + def __init__(self) -> None: self._values: StrAny = {} pass @property def name(self) -> str: - return "Dictionary Provider" + return self.NAME def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Optional[Any], str]: full_path = namespaces + (key,) diff --git a/dlt/common/configuration/providers/toml.py b/dlt/common/configuration/providers/toml.py index 6b433ce146..ee5a4c75aa 100644 --- a/dlt/common/configuration/providers/toml.py +++ b/dlt/common/configuration/providers/toml.py @@ -8,7 +8,7 @@ from functools import wraps from dlt.common.typing import DictStrAny, StrAny, TAny, TFun -from dlt.common.configuration import make_configuration, is_valid_hint +from dlt.common.configuration import resolve_configuration, is_valid_hint from dlt.common.configuration.specs import BaseConfiguration diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index 6509298044..e9ae20207c 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -1,48 +1,70 @@ import ast +from contextlib import _GeneratorContextManager import inspect -import dataclasses from collections.abc import Mapping as C_Mapping -from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, get_origin +from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple, Type, TypeVar, get_origin from dlt.common import json, logger from dlt.common.typing import TSecretValue, is_optional_type, extract_inner_type from dlt.common.schema.utils import coerce_type, py_type_to_sc_type -from dlt.common.configuration.specs.base_configuration import BaseConfiguration, CredentialsConfiguration, configspec -from dlt.common.configuration.container import Container, 
ContainerInjectableConfiguration -from dlt.common.configuration.specs.config_providers_configuration import ConfigProvidersListConfiguration -from dlt.common.configuration.providers.container import ContainerProvider +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, CredentialsConfiguration +from dlt.common.configuration.specs.config_namespace_context import ConfigNamespacesContext +from dlt.common.configuration.container import Container +from dlt.common.configuration.specs.config_providers_context import ConfigProvidersListContext +from dlt.common.configuration.providers.container import ContextProvider from dlt.common.configuration.exceptions import (LookupTrace, ConfigEntryMissingException, ConfigurationWrongTypeException, ConfigEnvValueCannotBeCoercedException, ValueNotSecretException, InvalidInitialValue) CHECK_INTEGRITY_F: str = "check_integrity" TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration) -def make_configuration(config: TConfiguration, *, namespaces: Tuple[str, ...] = (), initial_value: Any = None, accept_partial: bool = False) -> TConfiguration: +def resolve_configuration(config: TConfiguration, *, namespaces: Tuple[str, ...] = (), initial_value: Any = None, accept_partial: bool = False) -> TConfiguration: if not isinstance(config, BaseConfiguration): raise ConfigurationWrongTypeException(type(config)) - # parse initial value if possible - if initial_value is not None: - try: - config.from_native_representation(initial_value) - except (NotImplementedError, ValueError): - # if parsing failed and initial_values is dict then apply - # TODO: we may try to parse with json here if str - if isinstance(initial_value, C_Mapping): - config.update(initial_value) - else: - raise InvalidInitialValue(type(config), type(initial_value)) + return _resolve_configuration(config, namespaces, (), initial_value, accept_partial) + +def _resolve_configuration( + config: TConfiguration, + explicit_namespaces: Tuple[str, ...], + embedded_namespaces: Tuple[str, ...], + initial_value: Any, accept_partial: bool + ) -> TConfiguration: + # do not resolve twice + if config.is_resolved(): + return config + + config.__exception__ = None try: - _resolve_config_fields(config, namespaces, accept_partial) - _check_configuration_integrity(config) - # full configuration was resolved - config.__is_resolved__ = True - except ConfigEntryMissingException: - if not accept_partial: - raise - # _add_module_version(config) + # parse initial value if possible + if initial_value is not None: + try: + config.from_native_representation(initial_value) + except (NotImplementedError, ValueError): + # if parsing failed and initial_values is dict then apply + # TODO: we may try to parse with json here if str + if isinstance(initial_value, C_Mapping): + config.update(initial_value) + else: + raise InvalidInitialValue(type(config), type(initial_value)) + + try: + _resolve_config_fields(config, explicit_namespaces, embedded_namespaces, accept_partial) + _check_configuration_integrity(config) + # full configuration was resolved + config.__is_resolved__ = True + except ConfigEntryMissingException as cm_ex: + if not accept_partial: + raise + else: + # store the ConfigEntryMissingException to have full info on traces of missing fields + config.__exception__ = cm_ex + except Exception as ex: + # store the exception that happened in the resolution process + config.__exception__ = ex + raise return config @@ -86,6 +108,26 @@ def serialize_value(value: Any) -> Any: return coerce_type("text", 
value_dt, value) +def inject_namespace(namespace_context: ConfigNamespacesContext, merge_existing: bool = True) -> Generator[ConfigNamespacesContext, None, None]: + """Adds `namespace` context to container, making it injectable. Optionally merges the context already in the container with the one provided + + Args: + namespace_context (ConfigNamespacesContext): Instance providing a pipeline name and namespace context + merge_existing (bool, optional): Gets `pipeline_name` and `namespaces` from existing context if they are not provided in `namespace` argument. Defaults to True. + + Yields: + Iterator[ConfigNamespacesContext]: Context manager with current namespace context + """ + container = Container() + existing_context = container[ConfigNamespacesContext] + + if merge_existing: + namespace_context.pipeline_name = namespace_context.pipeline_name or existing_context.pipeline_name + namespace_context.namespaces = namespace_context.namespaces or existing_context.namespaces + + return container.injectable_context(namespace_context) + + # def _add_module_version(config: BaseConfiguration) -> None: # try: # v = sys._getframe(1).f_back.f_globals["__version__"] @@ -95,7 +137,13 @@ def serialize_value(value: Any) -> Any: # pass -def _resolve_config_fields(config: BaseConfiguration, namespaces: Tuple[str, ...], accept_partial: bool) -> None: +def _resolve_config_fields( + config: BaseConfiguration, + explicit_namespaces: Tuple[str, ...], + embedded_namespaces: Tuple[str, ...], + accept_partial: bool + ) -> None: + fields = config.get_resolvable_fields() unresolved_fields: Dict[str, Sequence[LookupTrace]] = {} @@ -108,11 +156,13 @@ def _resolve_config_fields(config: BaseConfiguration, namespaces: Tuple[str, ... accept_partial = accept_partial or is_optional # if actual value is BaseConfiguration, resolve that instance if isinstance(current_value, BaseConfiguration): - # add key as innermost namespace - current_value = make_configuration(current_value, namespaces=namespaces + (key,), accept_partial=accept_partial) + # resolve only if not yet resolved otherwise just pass it + if not current_value.is_resolved(): + # add key as innermost namespace + current_value = _resolve_configuration(current_value, explicit_namespaces, embedded_namespaces + (key,), None, accept_partial) else: # resolve key value via active providers - value, traces = _resolve_single_field(key, hint, config.__namespace__, *namespaces) + value, traces = _resolve_single_field(key, hint, config.__namespace__, explicit_namespaces, embedded_namespaces) # log trace if logger.is_logging() and logger.log_level() == "DEBUG": @@ -133,7 +183,7 @@ def _resolve_config_fields(config: BaseConfiguration, namespaces: Tuple[str, ... 
current_value = value else: # create new instance and pass value from the provider as initial, add key to namespaces - current_value = make_configuration(hint(), namespaces=namespaces + (key,), initial_value=value or current_value, accept_partial=accept_partial) + current_value = _resolve_configuration(hint(), explicit_namespaces, embedded_namespaces + (key,), value or current_value, accept_partial) else: if value is not None: current_value = deserialize_value(key, value, hint) @@ -160,18 +210,18 @@ def _check_configuration_integrity(config: BaseConfiguration) -> None: c.__dict__[CHECK_INTEGRITY_F](config) -@configspec(init=True) -class ConfigNamespacesConfiguration(ContainerInjectableConfiguration): - pipeline_name: Optional[str] - namespaces: List[str] = dataclasses.field(default_factory=lambda: []) - - -def _resolve_single_field(key: str, hint: Type[Any], config_namespace: str, *namespaces: str) -> Tuple[Optional[Any], List[LookupTrace]]: +def _resolve_single_field( + key: str, + hint: Type[Any], + config_namespace: str, + explicit_namespaces: Tuple[str, ...], + embedded_namespaces: Tuple[str, ...] + ) -> Tuple[Optional[Any], List[LookupTrace]]: container = Container() # get providers from container - providers = container[ConfigProvidersListConfiguration].providers + providers = container[ConfigProvidersListContext].providers # get additional namespaces to look in from container - ctx_namespaces = container[ConfigNamespacesConfiguration] + namespaces_context = container[ConfigNamespacesContext] # pipeline_name = ctx_namespaces.pipeline_name # start looking from the top provider with most specific set of namespaces first @@ -181,7 +231,13 @@ def _resolve_single_field(key: str, hint: Type[Any], config_namespace: str, *nam def look_namespaces(pipeline_name: str = None) -> Any: for provider in providers: if provider.supports_namespaces: - ns = [*ctx_namespaces.namespaces, *namespaces] + # if explicit namespaces are provided, ignore the injected context + if explicit_namespaces: + ns = list(explicit_namespaces) + else: + ns = list(namespaces_context.namespaces) + # always extend with embedded namespaces + ns.extend(embedded_namespaces) else: # if provider does not support namespaces and pipeline name is set then ignore it if pipeline_name: @@ -204,7 +260,7 @@ def look_namespaces(pipeline_name: str = None) -> Any: full_ns = ns value, ns_key = provider.get_value(key, hint, *full_ns) # create trace, ignore container provider - if provider.name != ContainerProvider.NAME: + if provider.name != ContextProvider.NAME: traces.append(LookupTrace(provider.name, full_ns, ns_key, value)) # if secret is obtained from non secret provider, we must fail if value is not None and not provider.supports_secrets and (hint is TSecretValue or (inspect.isclass(hint) and issubclass(hint, CredentialsConfiguration))): @@ -219,8 +275,8 @@ def look_namespaces(pipeline_name: str = None) -> Any: ns.pop() # first try with pipeline name as namespace, if present - if ctx_namespaces.pipeline_name: - value = look_namespaces(ctx_namespaces.pipeline_name) + if namespaces_context.pipeline_name: + value = look_namespaces(namespaces_context.pipeline_name) # then without it if value is None: value = look_namespaces() diff --git a/dlt/common/configuration/specs/__init__.py b/dlt/common/configuration/specs/__init__.py index f6c33bceb2..c5efbc46a8 100644 --- a/dlt/common/configuration/specs/__init__.py +++ b/dlt/common/configuration/specs/__init__.py @@ -1,8 +1,10 @@ from .run_configuration import RunConfiguration # noqa: F401 -from 
.base_configuration import BaseConfiguration, CredentialsConfiguration # noqa: F401 +from .base_configuration import BaseConfiguration, CredentialsConfiguration, ContainerInjectableContext # noqa: F401 from .normalize_volume_configuration import NormalizeVolumeConfiguration # noqa: F401 from .load_volume_configuration import LoadVolumeConfiguration # noqa: F401 from .schema_volume_configuration import SchemaVolumeConfiguration, TSchemaFileFormat # noqa: F401 from .pool_runner_configuration import PoolRunnerConfiguration, TPoolType # noqa: F401 from .gcp_client_credentials import GcpClientCredentials # noqa: F401 -from .postgres_credentials import PostgresCredentials # noqa: F401 \ No newline at end of file +from .postgres_credentials import PostgresCredentials # noqa: F401 +from .destination_capabilities_context import DestinationCapabilitiesContext # noqa: F401 +from .config_namespace_context import ConfigNamespacesContext # noqa: F401 \ No newline at end of file diff --git a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py index 9895052e08..2f9b1119ef 100644 --- a/dlt/common/configuration/specs/base_configuration.py +++ b/dlt/common/configuration/specs/base_configuration.py @@ -1,7 +1,7 @@ +import inspect import contextlib import dataclasses - -from typing import Callable, Optional, Union, Any, Dict, Iterator, MutableMapping, Type, TYPE_CHECKING, get_origin, overload +from typing import Callable, Optional, Union, Any, Dict, Iterator, MutableMapping, Type, TYPE_CHECKING, get_origin, overload, ClassVar if TYPE_CHECKING: TDtcField = dataclasses.Field[Any] @@ -18,7 +18,10 @@ def is_valid_hint(hint: Type[Any]) -> bool: hint = get_origin(hint) or hint if hint is Any: return True - if issubclass(hint, BaseConfiguration): + if hint is ClassVar: + # class vars are skipped by dataclass + return True + if inspect.isclass(hint) and issubclass(hint, BaseConfiguration): return True with contextlib.suppress(TypeError): py_type_to_sc_type(hint) @@ -50,8 +53,9 @@ def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]: if not hasattr(cls, ann) and not ann.startswith(("__", "_abc_impl")): setattr(cls, ann, None) # get all attributes without corresponding annotations - for att_name, att in cls.__dict__.items(): - if not callable(att) and not att_name.startswith(("__", "_abc_impl")): + for att_name, att_value in cls.__dict__.items(): + # skip callables, dunder names, class variables and some special names + if not callable(att_value) and not att_name.startswith(("__", "_abc_impl")): if att_name not in cls.__annotations__: raise ConfigFieldMissingTypeHintException(att_name, cls) hint = cls.__annotations__[att_name] @@ -74,6 +78,8 @@ class BaseConfiguration(MutableMapping[str, Any]): __is_resolved__: bool = dataclasses.field(default = False, init=False, repr=False) # namespace used by config providers when searching for keys __namespace__: str = dataclasses.field(default = None, init=False, repr=False) + # holds the exception that prevented the full resolution + __exception__: Exception = dataclasses.field(default = None, init=False, repr=False) def __init__(self) -> None: self.__ignore_set_unknown_keys = False @@ -163,3 +169,10 @@ class CredentialsConfiguration(BaseConfiguration): """Base class for all credentials. Credentials are configurations that may be stored only by providers supporting secrets.""" pass + +@configspec +class ContainerInjectableContext(BaseConfiguration): + """Base class for all configurations that may be injected from Container. 
Injectable configurations are called contexts""" + + # If True, `Container` is allowed to create default context instance, if none exists + can_create_default: ClassVar[bool] = True diff --git a/dlt/common/configuration/specs/config_namespace_context.py b/dlt/common/configuration/specs/config_namespace_context.py new file mode 100644 index 0000000000..5c4bbd2725 --- /dev/null +++ b/dlt/common/configuration/specs/config_namespace_context.py @@ -0,0 +1,14 @@ +from typing import List, Optional, Tuple, TYPE_CHECKING + +from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext, configspec + + +@configspec(init=True) +class ConfigNamespacesContext(ContainerInjectableContext): + pipeline_name: Optional[str] + namespaces: Tuple[str, ...] = () + + if TYPE_CHECKING: + # provide __init__ signature when type checking + def __init__(self, pipeline_name:str = None, namespaces: Tuple[str, ...] = ()) -> None: + ... diff --git a/dlt/common/configuration/specs/config_providers_configuration.py b/dlt/common/configuration/specs/config_providers_configuration.py deleted file mode 100644 index a10c258a19..0000000000 --- a/dlt/common/configuration/specs/config_providers_configuration.py +++ /dev/null @@ -1,25 +0,0 @@ - - -from typing import List - -from dlt.common.configuration.providers import Provider -from dlt.common.configuration.container import ContainerInjectableConfiguration -from dlt.common.configuration.providers.environ import EnvironProvider -from dlt.common.configuration.providers.container import ContainerProvider -from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec - - -@configspec -class ConfigProvidersListConfiguration(ContainerInjectableConfiguration): - providers: List[Provider] - - def __init__(self) -> None: - super().__init__() - # add default providers, ContainerProvider must be always first - it will provide injectable configs - self.providers = [ContainerProvider(), EnvironProvider()] - - -@configspec -class ConfigProvidersConfiguration(BaseConfiguration): - with_aws_secrets: bool = False - with_google_secrets: bool = False diff --git a/dlt/common/configuration/specs/config_providers_context.py b/dlt/common/configuration/specs/config_providers_context.py new file mode 100644 index 0000000000..00dc9d7efb --- /dev/null +++ b/dlt/common/configuration/specs/config_providers_context.py @@ -0,0 +1,43 @@ + + +from typing import List + +from dlt.common.configuration.providers import Provider +from dlt.common.configuration.providers.environ import EnvironProvider +from dlt.common.configuration.providers.container import ContextProvider +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, ContainerInjectableContext, configspec + + +@configspec +class ConfigProvidersListContext(ContainerInjectableContext): + """Injectable list of providers used by the configuration `resolve` module""" + providers: List[Provider] + + def __init__(self) -> None: + super().__init__() + # add default providers, ContextProvider must be always first - it will provide contexts + self.providers = [ContextProvider(), EnvironProvider()] + + def get_provider(self, name: str) -> Provider: + try: + return next(p for p in self.providers if p.name == name) + except StopIteration: + raise KeyError(name) + + def has_provider(self, name: str) -> bool: + try: + self.get_provider(name) + return True + except KeyError: + return False + + def add_provider(self, provider: Provider) -> None: + if self.has_provider(provider.name): + raise 
DuplicateProviderException(provider.name) + self.providers.append(provider) + + +@configspec +class ConfigProvidersConfiguration(BaseConfiguration): + with_aws_secrets: bool = False + with_google_secrets: bool = False diff --git a/dlt/common/configuration/specs/destination_capabilities_context.py b/dlt/common/configuration/specs/destination_capabilities_context.py new file mode 100644 index 0000000000..a5832f383b --- /dev/null +++ b/dlt/common/configuration/specs/destination_capabilities_context.py @@ -0,0 +1,25 @@ +from typing import List, ClassVar, Literal + +from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext, configspec + +# known loader file formats +# jsonl - new line separated json documents +# puae-jsonl - internal extract -> normalize format bases on jsonl +# insert_values - insert SQL statements +TLoaderFileFormat = Literal["jsonl", "puae-jsonl", "insert_values"] + + +@configspec(init=True) +class DestinationCapabilitiesContext(ContainerInjectableContext): + """Injectable destination capabilities required for many Pipeline stages ie. normalize""" + preferred_loader_file_format: TLoaderFileFormat + supported_loader_file_formats: List[TLoaderFileFormat] + max_identifier_length: int + max_column_length: int + max_query_length: int + is_max_query_length_in_bytes: bool + max_text_data_type_length: int + is_max_text_data_type_length_in_bytes: bool + + # do not allow to create default value, destination caps must be always explicitly inserted into container + can_create_default: ClassVar[bool] = False diff --git a/dlt/common/configuration/specs/load_volume_configuration.py b/dlt/common/configuration/specs/load_volume_configuration.py index b626622a24..3846b78bd9 100644 --- a/dlt/common/configuration/specs/load_volume_configuration.py +++ b/dlt/common/configuration/specs/load_volume_configuration.py @@ -1,7 +1,7 @@ from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec -@configspec +@configspec(init=True) class LoadVolumeConfiguration(BaseConfiguration): load_volume_path: str = None # path to volume where files to be loaded to analytical storage are stored delete_completed_jobs: bool = False # if set to true the folder with completed jobs will be deleted diff --git a/dlt/common/configuration/specs/normalize_volume_configuration.py b/dlt/common/configuration/specs/normalize_volume_configuration.py index e1f2946947..49aa40df40 100644 --- a/dlt/common/configuration/specs/normalize_volume_configuration.py +++ b/dlt/common/configuration/specs/normalize_volume_configuration.py @@ -1,6 +1,12 @@ +from typing import TYPE_CHECKING + from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec @configspec(init=True) class NormalizeVolumeConfiguration(BaseConfiguration): normalize_volume_path: str = None # path to volume where normalized loader files will be stored + + if TYPE_CHECKING: + def __init__(self, normalize_volume_path: str = None) -> None: + ... 
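The `DestinationCapabilitiesContext` added above sets `can_create_default = False`, so the container never builds it implicitly. A minimal sketch of how such a context is expected to be injected; the capability values below are placeholders, not taken from any real destination:

```python
from dlt.common.configuration.container import Container
from dlt.common.configuration.specs import DestinationCapabilitiesContext

# placeholder capabilities; a real destination would supply its own limits
caps = DestinationCapabilitiesContext(
    preferred_loader_file_format="jsonl",
    supported_loader_file_formats=["jsonl", "insert_values"],
    max_identifier_length=127,
    max_query_length=16 * 1024 * 1024,
)

# without an injected instance, Container()[DestinationCapabilitiesContext] raises
# ContextDefaultCannotBeCreated because can_create_default is False
with Container().injectable_context(caps):
    assert Container()[DestinationCapabilitiesContext] is caps
```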
diff --git a/dlt/common/configuration/specs/pool_runner_configuration.py b/dlt/common/configuration/specs/pool_runner_configuration.py index 3e7962ed43..06a95ceff1 100644 --- a/dlt/common/configuration/specs/pool_runner_configuration.py +++ b/dlt/common/configuration/specs/pool_runner_configuration.py @@ -1,13 +1,12 @@ from typing import Literal, Optional -from dlt.common.configuration.specs.base_configuration import configspec -from dlt.common.configuration.specs.run_configuration import RunConfiguration +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec TPoolType = Literal["process", "thread", "none"] @configspec -class PoolRunnerConfiguration(RunConfiguration): +class PoolRunnerConfiguration(BaseConfiguration): pool_type: TPoolType = None # type of pool to run, must be set in derived configs workers: Optional[int] = None # how many threads/processes in the pool run_sleep: float = 0.5 # how long to sleep between runs with workload, seconds diff --git a/dlt/common/configuration/specs/run_configuration.py b/dlt/common/configuration/specs/run_configuration.py index 7e4c620b65..e19faf1116 100644 --- a/dlt/common/configuration/specs/run_configuration.py +++ b/dlt/common/configuration/specs/run_configuration.py @@ -8,7 +8,7 @@ @configspec class RunConfiguration(BaseConfiguration): - pipeline_name: Optional[str] = None # the name of the component + pipeline_name: Optional[str] = None sentry_dsn: Optional[str] = None # keep None to disable Sentry prometheus_port: Optional[int] = None # keep None to disable Prometheus log_format: str = '{asctime}|[{levelname:<21}]|{process}|{name}|{filename}|{funcName}:{lineno}|{message}' diff --git a/dlt/common/configuration/specs/schema_volume_configuration.py b/dlt/common/configuration/specs/schema_volume_configuration.py index b2be72ed28..324b2e418f 100644 --- a/dlt/common/configuration/specs/schema_volume_configuration.py +++ b/dlt/common/configuration/specs/schema_volume_configuration.py @@ -1,4 +1,4 @@ -from typing import Optional, Literal +from typing import Optional, Literal, TYPE_CHECKING from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec @@ -12,3 +12,7 @@ class SchemaVolumeConfiguration(BaseConfiguration): export_schema_path: Optional[str] = None # export schema to external location external_schema_format: TSchemaFileFormat = "yaml" # format in which to expect external schema external_schema_format_remove_defaults: bool = True # remove default values when exporting schema + + if TYPE_CHECKING: + def __init__(self, schema_volume_path: str = None) -> None: + ... 
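The test changes that follow exercise the renamed resolution API. As a short usage sketch of `resolve_configuration` together with `inject_namespace`; the pipeline name, namespace and path are illustrative only:

```python
from dlt.common.configuration import resolve_configuration, inject_namespace
from dlt.common.configuration.specs import ConfigNamespacesContext, SchemaVolumeConfiguration

# make the pipeline name and an extra namespace visible to namespace-aware providers
with inject_namespace(ConfigNamespacesContext(pipeline_name="my_pipeline", namespaces=("sources",))):
    # resolve_configuration replaces make_configuration; already resolved configs are
    # returned as-is and a failed resolution is preserved on __exception__
    schema_config = resolve_configuration(
        SchemaVolumeConfiguration(schema_volume_path="_storage/schemas"),
        accept_partial=True,
    )
```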
diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index 7b53438b11..d6437c2fbc 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -121,6 +121,7 @@ class MockProdConfiguration(RunConfiguration): class FieldWithNoDefaultConfiguration(RunConfiguration): no_default: str + @configspec(init=True) class InstrumentedConfiguration(BaseConfiguration): head: str @@ -155,6 +156,11 @@ class EmbeddedOptionalConfiguration(BaseConfiguration): instrumented: Optional[InstrumentedConfiguration] +@configspec +class EmbeddedSecretConfiguration(BaseConfiguration): + secret: SecretConfiguration + + LongInteger = NewType("LongInteger", int) FirstOrderStr = NewType("FirstOrderStr", str) SecondOrderStr = NewType("SecondOrderStr", FirstOrderStr) @@ -172,22 +178,22 @@ def test_initial_config_state() -> None: def test_set_initial_config_value(environment: Any) -> None: # set from init method - C = resolve.make_configuration(InstrumentedConfiguration(head="h", tube=["a", "b"], heels="he")) + C = resolve.resolve_configuration(InstrumentedConfiguration(head="h", tube=["a", "b"], heels="he")) assert C.to_native_representation() == "h>a>b>he" # set from native form - C = resolve.make_configuration(InstrumentedConfiguration(), initial_value="h>a>b>he") + C = resolve.resolve_configuration(InstrumentedConfiguration(), initial_value="h>a>b>he") assert C.head == "h" assert C.tube == ["a", "b"] assert C.heels == "he" # set from dictionary - C = resolve.make_configuration(InstrumentedConfiguration(), initial_value={"head": "h", "tube": ["tu", "be"], "heels": "xhe"}) + C = resolve.resolve_configuration(InstrumentedConfiguration(), initial_value={"head": "h", "tube": ["tu", "be"], "heels": "xhe"}) assert C.to_native_representation() == "h>tu>be>xhe" def test_invalid_initial_config_value() -> None: # 2137 cannot be parsed and also is not a dict that can initialize the fields with pytest.raises(InvalidInitialValue) as py_ex: - resolve.make_configuration(InstrumentedConfiguration(), initial_value=2137) + resolve.resolve_configuration(InstrumentedConfiguration(), initial_value=2137) assert py_ex.value.spec is InstrumentedConfiguration assert py_ex.value.initial_value_type is int @@ -195,32 +201,32 @@ def test_invalid_initial_config_value() -> None: def test_check_integrity(environment: Any) -> None: with pytest.raises(RuntimeError): # head over hells - resolve.make_configuration(InstrumentedConfiguration(), initial_value="he>a>b>h") + resolve.resolve_configuration(InstrumentedConfiguration(), initial_value="he>a>b>h") def test_embedded_config(environment: Any) -> None: # resolve all embedded config, using initial value for instrumented config and initial dict for namespaced config - C = resolve.make_configuration(EmbeddedConfiguration(), initial_value={"default": "set", "instrumented": "h>tu>be>xhe", "namespaced": {"password": "pwd"}}) + C = resolve.resolve_configuration(EmbeddedConfiguration(), initial_value={"default": "set", "instrumented": "h>tu>be>xhe", "namespaced": {"password": "pwd"}}) assert C.default == "set" assert C.instrumented.to_native_representation() == "h>tu>be>xhe" assert C.namespaced.password == "pwd" # resolve but providing values via env with custom_environ({"INSTRUMENTED": "h>tu>u>be>xhe", "DLT_TEST__PASSWORD": "passwd", "DEFAULT": "DEF"}): - C = resolve.make_configuration(EmbeddedConfiguration()) + C = resolve.resolve_configuration(EmbeddedConfiguration()) assert C.default == "DEF" 
assert C.instrumented.to_native_representation() == "h>tu>u>be>xhe" assert C.namespaced.password == "passwd" # resolve partial, partial is passed to embedded - C = resolve.make_configuration(EmbeddedConfiguration(), accept_partial=True) + C = resolve.resolve_configuration(EmbeddedConfiguration(), accept_partial=True) assert not C.__is_resolved__ assert not C.namespaced.__is_resolved__ assert not C.instrumented.__is_resolved__ # some are partial, some are not with custom_environ({"DLT_TEST__PASSWORD": "passwd"}): - C = resolve.make_configuration(EmbeddedConfiguration(), accept_partial=True) + C = resolve.resolve_configuration(EmbeddedConfiguration(), accept_partial=True) assert not C.__is_resolved__ assert C.namespaced.__is_resolved__ assert not C.instrumented.__is_resolved__ @@ -228,17 +234,17 @@ def test_embedded_config(environment: Any) -> None: # single integrity error fails all the embeds with custom_environ({"INSTRUMENTED": "he>tu>u>be>h"}): with pytest.raises(RuntimeError): - resolve.make_configuration(EmbeddedConfiguration(), initial_value={"default": "set", "namespaced": {"password": "pwd"}}) + resolve.resolve_configuration(EmbeddedConfiguration(), initial_value={"default": "set", "namespaced": {"password": "pwd"}}) # part via env part via initial values with custom_environ({"INSTRUMENTED": "h>tu>u>be>he"}): - C = resolve.make_configuration(EmbeddedConfiguration(), initial_value={"default": "set", "namespaced": {"password": "pwd"}}) + C = resolve.resolve_configuration(EmbeddedConfiguration(), initial_value={"default": "set", "namespaced": {"password": "pwd"}}) assert C.instrumented.to_native_representation() == "h>tu>u>be>he" def test_provider_values_over_initial(environment: Any) -> None: with custom_environ({"INSTRUMENTED": "h>tu>u>be>he"}): - C = resolve.make_configuration(EmbeddedConfiguration(), initial_value={"instrumented": "h>tu>be>xhe"}, accept_partial=True) + C = resolve.resolve_configuration(EmbeddedConfiguration(), initial_value={"instrumented": "h>tu>be>xhe"}, accept_partial=True) assert C.instrumented.to_native_representation() == "h>tu>u>be>he" # parent configuration is not resolved assert not C.is_resolved() @@ -250,7 +256,7 @@ def test_provider_values_over_initial(environment: Any) -> None: def test_run_configuration_gen_name(environment: Any) -> None: - C = resolve.make_configuration(RunConfiguration()) + C = resolve.resolve_configuration(RunConfiguration()) assert C.pipeline_name.startswith("dlt_") @@ -278,7 +284,7 @@ class _SecretCredentials(RunConfiguration): assert dict(_SecretCredentials()) == expected_dict environment["SECRET_VALUE"] = "secret" - C = resolve.make_configuration(_SecretCredentials()) + C = resolve.resolve_configuration(_SecretCredentials()) expected_dict["secret_value"] = "secret" assert dict(C) == expected_dict @@ -346,7 +352,7 @@ class MultiConfiguration(MockProdConfiguration, ConfigurationWithOptionalTypes, def test_raises_on_unresolved_field(environment: Any) -> None: # via make configuration with pytest.raises(ConfigEntryMissingException) as cf_missing_exc: - resolve.make_configuration(WrongConfiguration()) + resolve.resolve_configuration(WrongConfiguration()) assert cf_missing_exc.value.spec_name == "WrongConfiguration" assert "NoneConfigVar" in cf_missing_exc.value.traces # has only one trace @@ -358,7 +364,7 @@ def test_raises_on_unresolved_field(environment: Any) -> None: def test_raises_on_many_unresolved_fields(environment: Any) -> None: # via make configuration with pytest.raises(ConfigEntryMissingException) as cf_missing_exc: - 
resolve.make_configuration(CoercionTestConfiguration()) + resolve.resolve_configuration(CoercionTestConfiguration()) assert cf_missing_exc.value.spec_name == "CoercionTestConfiguration" # get all fields that must be set val_fields = [f for f in CoercionTestConfiguration().get_resolvable_fields() if f.lower().endswith("_val")] @@ -374,11 +380,11 @@ def test_accepts_optional_missing_fields(environment: Any) -> None: C = ConfigurationWithOptionalTypes() assert not C.is_partial() # make optional config - resolve.make_configuration(ConfigurationWithOptionalTypes()) + resolve.resolve_configuration(ConfigurationWithOptionalTypes()) # make config with optional values - resolve.make_configuration(ProdConfigurationWithOptionalTypes(), initial_value={"int_val": None}) + resolve.resolve_configuration(ProdConfigurationWithOptionalTypes(), initial_value={"int_val": None}) # make config with optional embedded config - C = resolve.make_configuration(EmbeddedOptionalConfiguration()) + C = resolve.resolve_configuration(EmbeddedOptionalConfiguration()) # embedded config was not fully resolved assert not C.instrumented.__is_resolved__ assert not C.instrumented.is_resolved() @@ -395,7 +401,7 @@ def test_coercion_to_hint_types(environment: Any) -> None: add_config_dict_to_env(COERCIONS) C = CoercionTestConfiguration() - resolve._resolve_config_fields(C, namespaces=(), accept_partial=False) + resolve._resolve_config_fields(C, explicit_namespaces=(), embedded_namespaces=(), accept_partial=False) for key in COERCIONS: assert getattr(C, key) == COERCIONS[key] @@ -434,7 +440,7 @@ def test_invalid_coercions(environment: Any) -> None: add_config_dict_to_env(INVALID_COERCIONS) for key, value in INVALID_COERCIONS.items(): try: - resolve._resolve_config_fields(C, namespaces=(), accept_partial=False) + resolve._resolve_config_fields(C, explicit_namespaces=(), embedded_namespaces=(), accept_partial=False) except ConfigEnvValueCannotBeCoercedException as coerc_exc: # must fail exactly on expected value if coerc_exc.field_name != key: @@ -449,7 +455,7 @@ def test_excepted_coercions(environment: Any) -> None: C = CoercionTestConfiguration() add_config_dict_to_env(COERCIONS) add_config_dict_to_env(EXCEPTED_COERCIONS, overwrite_keys=True) - resolve._resolve_config_fields(C, namespaces=(), accept_partial=False) + resolve._resolve_config_fields(C, explicit_namespaces=(), embedded_namespaces=(), accept_partial=False) for key in EXCEPTED_COERCIONS: assert getattr(C, key) == COERCED_EXCEPTIONS[key] @@ -473,18 +479,18 @@ class NoHintConfiguration(BaseConfiguration): NoHintConfiguration() -def test_make_configuration(environment: Any) -> None: +def test_resolve_configuration(environment: Any) -> None: # fill up configuration environment["NONECONFIGVAR"] = "1" - C = resolve.make_configuration(WrongConfiguration()) + C = resolve.resolve_configuration(WrongConfiguration()) assert C.__is_resolved__ assert C.NoneConfigVar == "1" def test_dataclass_instantiation(environment: Any) -> None: - # make_configuration works on instances of dataclasses and types are not modified + # resolve_configuration works on instances of dataclasses and types are not modified environment['SECRET_VALUE'] = "1" - C = resolve.make_configuration(SecretConfiguration()) + C = resolve.resolve_configuration(SecretConfiguration()) # auto derived type holds the value assert C.secret_value == "1" # base type is untouched @@ -496,7 +502,7 @@ def test_initial_values(environment: Any) -> None: environment["PIPELINE_NAME"] = "env name" environment["CREATED_VAL"] = "12837" # 
set initial values and allow partial config - C = resolve.make_configuration(CoercionTestConfiguration(), + C = resolve.resolve_configuration(CoercionTestConfiguration(), initial_value={"pipeline_name": "initial name", "none_val": type(environment), "created_val": 878232, "bytes_val": b"str"}, accept_partial=True ) @@ -513,7 +519,7 @@ def test_accept_partial(environment: Any) -> None: # modify original type WrongConfiguration.NoneConfigVar = None # that None value will be present in the instance - C = resolve.make_configuration(WrongConfiguration(), accept_partial=True) + C = resolve.resolve_configuration(WrongConfiguration(), accept_partial=True) assert C.NoneConfigVar is None # partial resolution assert not C.__is_resolved__ @@ -582,17 +588,73 @@ def test_secret_value_not_secret_provider(mock_provider: MockProvider) -> None: # TSecretValue will fail with pytest.raises(ValueNotSecretException) as py_ex: - resolve.make_configuration(SecretConfiguration(), namespaces=("mock",)) + resolve.resolve_configuration(SecretConfiguration(), namespaces=("mock",)) assert py_ex.value.provider_name == "Mock Provider" assert py_ex.value.key == "-secret_value" # anything derived from CredentialsConfiguration will fail with pytest.raises(ValueNotSecretException) as py_ex: - resolve.make_configuration(WithCredentialsConfiguration(), namespaces=("mock",)) + resolve.resolve_configuration(WithCredentialsConfiguration(), namespaces=("mock",)) assert py_ex.value.provider_name == "Mock Provider" assert py_ex.value.key == "-credentials" +def test_do_not_resolve_twice(environment: Any) -> None: + environment["SECRET_VALUE"] = "password" + c = resolve.resolve_configuration(SecretConfiguration()) + assert c.secret_value == "password" + c2 = SecretConfiguration() + c2.secret_value = "other" + c2.__is_resolved__ = True + assert c2.is_resolved() + # will not overwrite with env + c3 = resolve.resolve_configuration(c2) + assert c3.secret_value == "other" + assert c3 is c2 + # make it not resolved + c2.__is_resolved__ = False + c4 = resolve.resolve_configuration(c2) + assert c4.secret_value == "password" + assert c2 is c3 is c4 + # also c is resolved so + c.secret_value = "else" + resolve.resolve_configuration(c).secret_value == "else" + + +def test_do_not_resolve_embedded(environment: Any) -> None: + environment["SECRET__SECRET_VALUE"] = "password" + c = resolve.resolve_configuration(EmbeddedSecretConfiguration()) + assert c.secret.secret_value == "password" + c2 = SecretConfiguration() + c2.secret_value = "other" + c2.__is_resolved__ = True + embed_c = EmbeddedSecretConfiguration() + embed_c.secret = c2 + embed_c2 = resolve.resolve_configuration(embed_c) + assert embed_c2.secret.secret_value == "other" + assert embed_c2.secret is c2 + + +def test_last_resolve_exception(environment: Any) -> None: + # partial will set the ConfigEntryMissingException + c = resolve.resolve_configuration(EmbeddedConfiguration(), accept_partial=True) + assert isinstance(c.__exception__, ConfigEntryMissingException) + # missing keys + c = SecretConfiguration() + with pytest.raises(ConfigEntryMissingException) as py_ex: + resolve.resolve_configuration(c) + assert c.__exception__ is py_ex.value + # but if ran again exception is cleared + environment["SECRET_VALUE"] = "password" + resolve.resolve_configuration(c) + assert c.__exception__ is None + # initial value + c = InstrumentedConfiguration() + with pytest.raises(InvalidInitialValue) as py_ex: + resolve.resolve_configuration(c, initial_value=2137) + assert c.__exception__ is py_ex.value + + 
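The three new tests above pin down the resolution contract: resolve_configuration returns an already-resolved instance unchanged, and __exception__ keeps the most recent resolution failure until a later successful resolve clears it. A minimal sketch of leaning on that contract, assuming only names that appear in this test module (resolve, ConfigEntryMissingException, BaseConfiguration); the resolve_with_fallback helper and its env_key/fallback parameters are illustrative, not part of the library or of this patch:

import os

from dlt.common.configuration import resolve, ConfigEntryMissingException
from dlt.common.configuration.specs import BaseConfiguration


def resolve_with_fallback(config: BaseConfiguration, env_key: str, fallback: str) -> BaseConfiguration:
    # already-resolved instances are returned as-is, so calling this repeatedly is harmless
    if config.is_resolved():
        return config
    try:
        return resolve.resolve_configuration(config)
    except ConfigEntryMissingException:
        # the failed attempt stays on config.__exception__ until the next successful resolve clears it
        assert config.__exception__ is not None
        os.environ[env_key] = fallback
        return resolve.resolve_configuration(config)

Used with the SecretConfiguration spec from these tests, a first call without SECRET_VALUE set fails the provider lookup, records the exception on the instance, injects the fallback into the environment and resolves again, mirroring test_last_resolve_exception.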
def coerce_single_value(key: str, value: str, hint: Type[Any]) -> Any: hint = extract_inner_type(hint) return resolve.deserialize_value(key, value, hint) diff --git a/tests/common/configuration/test_container.py b/tests/common/configuration/test_container.py index 7a73a73b46..786a04494f 100644 --- a/tests/common/configuration/test_container.py +++ b/tests/common/configuration/test_container.py @@ -1,26 +1,32 @@ -from typing import Any import pytest +from typing import Any, ClassVar, Literal from dlt.common.configuration import configspec -from dlt.common.configuration.providers.container import ContainerProvider -from dlt.common.configuration.resolve import make_configuration -from dlt.common.configuration.specs import BaseConfiguration -from dlt.common.configuration.container import Container, ContainerInjectableConfiguration -from dlt.common.configuration.exceptions import ContainerInjectableConfigurationMangled, InvalidInitialValue -from dlt.common.configuration.specs.config_providers_configuration import ConfigProvidersListConfiguration +from dlt.common.configuration.providers.container import ContextProvider +from dlt.common.configuration.resolve import resolve_configuration +from dlt.common.configuration.specs import BaseConfiguration, ContainerInjectableContext +from dlt.common.configuration.container import Container +from dlt.common.configuration.exceptions import ContainerInjectableContextMangled, InvalidInitialValue, ContextDefaultCannotBeCreated +from dlt.common.configuration.specs.config_providers_context import ConfigProvidersListContext from tests.utils import preserve_environ from tests.common.configuration.utils import environment @configspec(init=True) -class InjectableTestConfiguration(ContainerInjectableConfiguration): +class InjectableTestContext(ContainerInjectableContext): current_value: str @configspec -class EmbeddedWithInjectableConfiguration(BaseConfiguration): - injected: InjectableTestConfiguration +class EmbeddedWithInjectableContext(BaseConfiguration): + injected: InjectableTestContext + + +@configspec +class NoDefaultInjectableContext(ContainerInjectableContext): + + can_create_default: ClassVar[bool] = False @pytest.fixture() @@ -32,87 +38,110 @@ def container() -> Container: def test_singleton(container: Container) -> None: # keep the old configurations list - container_configurations = container.configurations + container_configurations = container.contexts singleton = Container() # make sure it is the same object assert container is singleton # that holds the same configurations dictionary - assert container_configurations is singleton.configurations + assert container_configurations is singleton.contexts def test_get_default_injectable_config(container: Container) -> None: - pass + injectable = container[InjectableTestContext] + assert injectable.current_value is None + assert isinstance(injectable, InjectableTestContext) + + +def test_raise_on_no_default_value(container: Container) -> None: + with pytest.raises(ContextDefaultCannotBeCreated) as py_ex: + container[NoDefaultInjectableContext] + + # ok when injected + with container.injectable_context(NoDefaultInjectableContext()) as injected: + assert container[NoDefaultInjectableContext] is injected def test_container_injectable_context(container: Container) -> None: - with container.injectable_configuration(InjectableTestConfiguration()) as current_config: + with container.injectable_context(InjectableTestContext()) as current_config: assert current_config.current_value is None 
current_config.current_value = "TEST" - assert container[InjectableTestConfiguration].current_value == "TEST" - assert container[InjectableTestConfiguration] is current_config + assert container[InjectableTestContext].current_value == "TEST" + assert container[InjectableTestContext] is current_config - assert InjectableTestConfiguration not in container + assert InjectableTestContext not in container def test_container_injectable_context_restore(container: Container) -> None: # this will create InjectableTestConfiguration - original = container[InjectableTestConfiguration] + original = container[InjectableTestContext] original.current_value = "ORIGINAL" - with container.injectable_configuration(InjectableTestConfiguration()) as current_config: + with container.injectable_context(InjectableTestContext()) as current_config: current_config.current_value = "TEST" # nested context is supported - with container.injectable_configuration(InjectableTestConfiguration()) as inner_config: + with container.injectable_context(InjectableTestContext()) as inner_config: assert inner_config.current_value is None - assert container[InjectableTestConfiguration] is inner_config - assert container[InjectableTestConfiguration] is current_config + assert container[InjectableTestContext] is inner_config + assert container[InjectableTestContext] is current_config - assert container[InjectableTestConfiguration] is original - assert container[InjectableTestConfiguration].current_value == "ORIGINAL" + assert container[InjectableTestContext] is original + assert container[InjectableTestContext].current_value == "ORIGINAL" def test_container_injectable_context_mangled(container: Container) -> None: - original = container[InjectableTestConfiguration] + original = container[InjectableTestContext] original.current_value = "ORIGINAL" - injectable = InjectableTestConfiguration() - with pytest.raises(ContainerInjectableConfigurationMangled) as py_ex: - with container.injectable_configuration(injectable) as current_config: + context = InjectableTestContext() + with pytest.raises(ContainerInjectableContextMangled) as py_ex: + with container.injectable_context(context) as current_config: current_config.current_value = "TEST" # overwrite the config in container - container.configurations[InjectableTestConfiguration] = InjectableTestConfiguration() - assert py_ex.value.spec == InjectableTestConfiguration - assert py_ex.value.expected_config == injectable + container.contexts[InjectableTestContext] = InjectableTestContext() + assert py_ex.value.spec == InjectableTestContext + assert py_ex.value.expected_config == context def test_container_provider(container: Container) -> None: - provider = ContainerProvider() - v, k = provider.get_value("n/a", InjectableTestConfiguration) - # provider does not create default value in Container - assert v is None - assert k == str(InjectableTestConfiguration) - assert InjectableTestConfiguration not in container + provider = ContextProvider() + # default value will be created + v, k = provider.get_value("n/a", InjectableTestContext) + assert isinstance(v, InjectableTestContext) + assert k == "InjectableTestContext" + assert InjectableTestContext in container - original = container[InjectableTestConfiguration] - original.current_value = "ORIGINAL" - v, _ = provider.get_value("n/a", InjectableTestConfiguration) + # provider does not create default value in Container + with pytest.raises(ContextDefaultCannotBeCreated): + provider.get_value("n/a", NoDefaultInjectableContext) + assert 
NoDefaultInjectableContext not in container + + # explicitly create value + original = NoDefaultInjectableContext() + container.contexts[NoDefaultInjectableContext] = original + v, _ = provider.get_value("n/a", NoDefaultInjectableContext) assert v is original # must assert if namespaces are provided with pytest.raises(AssertionError): - provider.get_value("n/a", InjectableTestConfiguration, ("ns1",)) + provider.get_value("n/a", InjectableTestContext, ("ns1",)) + + # type hints that are not classes + l = Literal["a"] + v, k = provider.get_value("n/a", l) + assert v is None + assert k == "typing.Literal['a']" def test_container_provider_embedded_inject(container: Container, environment: Any) -> None: environment["INJECTED"] = "unparsable" - with container.injectable_configuration(InjectableTestConfiguration(current_value="Embed")) as injected: + with container.injectable_context(InjectableTestContext(current_value="Embed")) as injected: # must have top precedence - over the environ provider. environ provider is returning a value that will cannot be parsed # but the container provider has a precedence and the lookup in environ provider will never happen - C = make_configuration(EmbeddedWithInjectableConfiguration()) + C = resolve_configuration(EmbeddedWithInjectableContext()) assert C.injected.current_value == "Embed" assert C.injected is injected # remove first provider - container[ConfigProvidersListConfiguration].providers.pop(0) + container[ConfigProvidersListContext].providers.pop(0) # now environment will provide unparsable value with pytest.raises(InvalidInitialValue): - C = make_configuration(EmbeddedWithInjectableConfiguration()) + C = resolve_configuration(EmbeddedWithInjectableContext()) diff --git a/tests/common/configuration/test_environ_provider.py b/tests/common/configuration/test_environ_provider.py index 36285b0c41..a7392754d1 100644 --- a/tests/common/configuration/test_environ_provider.py +++ b/tests/common/configuration/test_environ_provider.py @@ -32,7 +32,7 @@ def test_resolves_from_environ(environment: Any) -> None: environment["NONECONFIGVAR"] = "Some" C = WrongConfiguration() - resolve._resolve_config_fields(C, namespaces=(), accept_partial=False) + resolve._resolve_config_fields(C, explicit_namespaces=(), embedded_namespaces=(), accept_partial=False) assert not C.is_partial() assert C.NoneConfigVar == environment["NONECONFIGVAR"] @@ -42,7 +42,7 @@ def test_resolves_from_environ_with_coercion(environment: Any) -> None: environment["TEST_BOOL"] = 'yes' C = SimpleConfiguration() - resolve._resolve_config_fields(C, namespaces=(), accept_partial=False) + resolve._resolve_config_fields(C, explicit_namespaces=(), embedded_namespaces=(), accept_partial=False) assert not C.is_partial() # value will be coerced to bool @@ -51,9 +51,9 @@ def test_resolves_from_environ_with_coercion(environment: Any) -> None: def test_secret(environment: Any) -> None: with pytest.raises(ConfigEntryMissingException): - resolve.make_configuration(SecretConfiguration()) + resolve.resolve_configuration(SecretConfiguration()) environment['SECRET_VALUE'] = "1" - C = resolve.make_configuration(SecretConfiguration()) + C = resolve.resolve_configuration(SecretConfiguration()) assert C.secret_value == "1" # mock the path to point to secret storage # from dlt.common.configuration import config_utils @@ -62,18 +62,18 @@ def test_secret(environment: Any) -> None: try: # must read a secret file environ_provider.SECRET_STORAGE_PATH = "./tests/common/cases/%s" - C = 
resolve.make_configuration(SecretConfiguration()) + C = resolve.resolve_configuration(SecretConfiguration()) assert C.secret_value == "BANANA" # set some weird path, no secret file at all del environment['SECRET_VALUE'] environ_provider.SECRET_STORAGE_PATH = "!C:\\PATH%s" with pytest.raises(ConfigEntryMissingException): - resolve.make_configuration(SecretConfiguration()) + resolve.resolve_configuration(SecretConfiguration()) # set env which is a fallback for secret not as file environment['SECRET_VALUE'] = "1" - C = resolve.make_configuration(SecretConfiguration()) + C = resolve.resolve_configuration(SecretConfiguration()) assert C.secret_value == "1" finally: environ_provider.SECRET_STORAGE_PATH = path @@ -83,7 +83,7 @@ def test_secret_kube_fallback(environment: Any) -> None: path = environ_provider.SECRET_STORAGE_PATH try: environ_provider.SECRET_STORAGE_PATH = "./tests/common/cases/%s" - C = resolve.make_configuration(SecretKubeConfiguration()) + C = resolve.resolve_configuration(SecretKubeConfiguration()) # all unix editors will add x10 at the end of file, it will be preserved assert C.secret_kube == "kube\n" # we propagate secrets back to environ and strip the whitespace @@ -95,7 +95,7 @@ def test_secret_kube_fallback(environment: Any) -> None: def test_configuration_files(environment: Any) -> None: # overwrite config file paths environment["CONFIG_FILES_STORAGE_PATH"] = "./tests/common/cases/schemas/ev1/%s" - C = resolve.make_configuration(MockProdConfigurationVar()) + C = resolve.resolve_configuration(MockProdConfigurationVar()) assert C.config_files_storage_path == environment["CONFIG_FILES_STORAGE_PATH"] assert C.has_configuration_file("hasn't") is False assert C.has_configuration_file("event_schema.json") is True diff --git a/tests/common/configuration/test_inject.py b/tests/common/configuration/test_inject.py index 01876e1b9d..d889263e2a 100644 --- a/tests/common/configuration/test_inject.py +++ b/tests/common/configuration/test_inject.py @@ -6,6 +6,8 @@ from dlt.common.configuration.inject import _spec_from_signature, _get_spec_name_from_f, with_config from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration +from tests.utils import preserve_environ +from tests.common.configuration.utils import environment _DECIMAL_DEFAULT = Decimal("0.01") _SECRET_DEFAULT = TSecretValue("PASS") @@ -108,6 +110,24 @@ def test_inject_with_non_injectable_param() -> None: pass +def test_inject_without_spec() -> None: + pass + + +def test_inject_without_spec_kw_only() -> None: + pass + + +def test_inject_with_auto_namespace(environment: Any) -> None: + environment["PIPE__VALUE"] = "test" + + @with_config(auto_namespace=True) + def f(pipeline_name, value): + assert value == "test" + + f("pipe") + + def test_inject_with_spec() -> None: pass diff --git a/tests/common/configuration/test_namespaces.py b/tests/common/configuration/test_namespaces.py index d79cf9a2f8..e8f462e90c 100644 --- a/tests/common/configuration/test_namespaces.py +++ b/tests/common/configuration/test_namespaces.py @@ -1,12 +1,10 @@ -from unittest import mock import pytest from typing import Any, Optional from dlt.common.configuration.container import Container -from dlt.common.typing import TSecretValue -from dlt.common.configuration import configspec, ConfigEntryMissingException, ConfigFileNotFoundException, resolve -from dlt.common.configuration.specs import BaseConfiguration -from dlt.common.configuration.providers import environ as environ_provider +from dlt.common.configuration import configspec, 
ConfigEntryMissingException, resolve, inject_namespace +from dlt.common.configuration.specs import BaseConfiguration, ConfigNamespacesContext +# from dlt.common.configuration.providers import environ as environ_provider from dlt.common.configuration.exceptions import LookupTrace from tests.utils import preserve_environ @@ -25,7 +23,8 @@ class EmbeddedConfiguration(BaseConfiguration): def test_namespaced_configuration(environment: Any) -> None: with pytest.raises(ConfigEntryMissingException) as exc_val: - resolve.make_configuration(NamespacedConfiguration()) + resolve.resolve_configuration(NamespacedConfiguration()) + assert list(exc_val.value.traces.keys()) == ["password"] assert exc_val.value.spec_name == "NamespacedConfiguration" # check trace @@ -35,15 +34,15 @@ def test_namespaced_configuration(environment: Any) -> None: assert traces[0] == LookupTrace("Environment Variables", ["DLT_TEST"], "DLT_TEST__PASSWORD", None) # init vars work without namespace - C = resolve.make_configuration(NamespacedConfiguration(), initial_value={"password": "PASS"}) + C = resolve.resolve_configuration(NamespacedConfiguration(), initial_value={"password": "PASS"}) assert C.password == "PASS" # env var must be prefixed environment["PASSWORD"] = "PASS" with pytest.raises(ConfigEntryMissingException) as exc_val: - resolve.make_configuration(NamespacedConfiguration()) + resolve.resolve_configuration(NamespacedConfiguration()) environment["DLT_TEST__PASSWORD"] = "PASS" - C = resolve.make_configuration(NamespacedConfiguration()) + C = resolve.resolve_configuration(NamespacedConfiguration()) assert C.password == "PASS" @@ -59,16 +58,16 @@ def test_explicit_namespaces(mock_provider: MockProvider) -> None: # via make configuration mock_provider.reset_stats() - resolve.make_configuration(SingleValConfiguration()) + resolve.resolve_configuration(SingleValConfiguration()) assert mock_provider.last_namespace == () mock_provider.reset_stats() - resolve.make_configuration(SingleValConfiguration(), namespaces=("ns1",)) + resolve.resolve_configuration(SingleValConfiguration(), namespaces=("ns1",)) # value is returned only on empty namespace assert mock_provider.last_namespace == () # always start with more precise namespace assert mock_provider.last_namespaces == [("ns1",), ()] mock_provider.reset_stats() - resolve.make_configuration(SingleValConfiguration(), namespaces=("ns1", "ns2")) + resolve.resolve_configuration(SingleValConfiguration(), namespaces=("ns1", "ns2")) assert mock_provider.last_namespaces == [("ns1", "ns2"), ("ns1",), ()] @@ -76,29 +75,29 @@ def test_explicit_namespaces_with_namespaced_config(mock_provider: MockProvider) mock_provider.value = "value" # with namespaced config mock_provider.return_value_on = ("DLT_TEST",) - resolve.make_configuration(NamespacedConfiguration()) + resolve.resolve_configuration(NamespacedConfiguration()) assert mock_provider.last_namespace == ("DLT_TEST",) # namespace from config is mandatory, provider will not be queried with () assert mock_provider.last_namespaces == [("DLT_TEST",)] # namespaced config is always innermost mock_provider.reset_stats() - resolve.make_configuration(NamespacedConfiguration(), namespaces=("ns1",)) + resolve.resolve_configuration(NamespacedConfiguration(), namespaces=("ns1",)) assert mock_provider.last_namespaces == [("ns1", "DLT_TEST"), ("DLT_TEST",)] mock_provider.reset_stats() - resolve.make_configuration(NamespacedConfiguration(), namespaces=("ns1", "ns2")) + resolve.resolve_configuration(NamespacedConfiguration(), namespaces=("ns1", "ns2")) 
assert mock_provider.last_namespaces == [("ns1", "ns2", "DLT_TEST"), ("ns1", "DLT_TEST"), ("DLT_TEST",)] def test_explicit_namespaces_from_embedded_config(mock_provider: MockProvider) -> None: mock_provider.value = {"sv": "A"} - C = resolve.make_configuration(EmbeddedConfiguration()) + C = resolve.resolve_configuration(EmbeddedConfiguration()) # we mock the dictionary below as the value for all requests assert C.sv_config.sv == '{"sv": "A"}' # following namespaces were used when resolving EmbeddedConfig: () - to resolve sv_config and then: ("sv_config",), () to resolve sv in sv_config assert mock_provider.last_namespaces == [(), ("sv_config",), ()] # embedded namespace inner of explicit mock_provider.reset_stats() - C = resolve.make_configuration(EmbeddedConfiguration(), namespaces=("ns1",)) + C = resolve.resolve_configuration(EmbeddedConfiguration(), namespaces=("ns1",)) assert mock_provider.last_namespaces == [("ns1",), (), ("ns1", "sv_config",), ("ns1",), ()] @@ -106,57 +105,54 @@ def test_injected_namespaces(mock_provider: MockProvider) -> None: container = Container() mock_provider.value = "value" - with container.injectable_configuration(resolve.ConfigNamespacesConfiguration(namespaces=("inj-ns1",))): - resolve.make_configuration(SingleValConfiguration()) + with container.injectable_context(ConfigNamespacesContext(namespaces=("inj-ns1",))): + resolve.resolve_configuration(SingleValConfiguration()) assert mock_provider.last_namespaces == [("inj-ns1",), ()] mock_provider.reset_stats() - # explicit namespace inner of injected - resolve.make_configuration(SingleValConfiguration(), namespaces=("ns1",)) - assert mock_provider.last_namespaces == [("inj-ns1", "ns1"), ("inj-ns1",), ()] + # explicit namespace preempts injected namespace + resolve.resolve_configuration(SingleValConfiguration(), namespaces=("ns1",)) + assert mock_provider.last_namespaces == [("ns1",), ()] # namespaced config inner of injected mock_provider.reset_stats() mock_provider.return_value_on = ("DLT_TEST",) - resolve.make_configuration(NamespacedConfiguration(), namespaces=("ns1",)) - assert mock_provider.last_namespaces == [("inj-ns1", "ns1", "DLT_TEST"), ("inj-ns1", "DLT_TEST"), ("DLT_TEST",)] - # explicit namespace inner of ns coming from embedded config + resolve.resolve_configuration(NamespacedConfiguration()) + assert mock_provider.last_namespaces == [("inj-ns1", "DLT_TEST"), ("DLT_TEST",)] + # injected namespace inner of ns coming from embedded config mock_provider.reset_stats() mock_provider.return_value_on = () mock_provider.value = {"sv": "A"} - resolve.make_configuration(EmbeddedConfiguration(), namespaces=("ns1",)) - # first we look for sv_config -> ("inj-ns1", "ns1"), ("inj-ns1",), () then we look for sv - assert mock_provider.last_namespaces == [("inj-ns1", "ns1"), ("inj-ns1",), (), ("inj-ns1", "ns1", "sv_config"), ("inj-ns1", "ns1"), ("inj-ns1",), ()] + resolve.resolve_configuration(EmbeddedConfiguration()) + # first we look for sv_config -> ("inj-ns1",), () then we look for sv + assert mock_provider.last_namespaces == [("inj-ns1", ), (), ("inj-ns1", "sv_config"), ("inj-ns1",), ()] # multiple injected namespaces - with container.injectable_configuration(resolve.ConfigNamespacesConfiguration(namespaces=("inj-ns1", "inj-ns2"))): + with container.injectable_context(ConfigNamespacesContext(namespaces=("inj-ns1", "inj-ns2"))): mock_provider.reset_stats() - resolve.make_configuration(SingleValConfiguration()) + resolve.resolve_configuration(SingleValConfiguration()) assert mock_provider.last_namespaces == 
[("inj-ns1", "inj-ns2"), ("inj-ns1",), ()] mock_provider.reset_stats() - # explicit namespace inner of injected - resolve.make_configuration(SingleValConfiguration(), namespaces=("ns1",)) - assert mock_provider.last_namespaces == [("inj-ns1", "inj-ns2", "ns1"), ("inj-ns1", "inj-ns2"), ("inj-ns1",), ()] -def test_namespace_from_pipeline_name(mock_provider: MockProvider) -> None: +def test_namespace_with_pipeline_name(mock_provider: MockProvider) -> None: # AXIES__DESTINATION__STORAGE_CREDENTIALS__PRIVATE_KEY, DESTINATION__STORAGE_CREDENTIALS__PRIVATE_KEY, DESTINATION__PRIVATE_KEY, GCP__PRIVATE_KEY # if pipeline name is present, keys will be looked up twice: with pipeline as top level namespace and without it container = Container() mock_provider.value = "value" - with container.injectable_configuration(resolve.ConfigNamespacesConfiguration(pipeline_name="PIPE")): + with container.injectable_context(ConfigNamespacesContext(pipeline_name="PIPE")): mock_provider.return_value_on = () - resolve.make_configuration(SingleValConfiguration()) + resolve.resolve_configuration(SingleValConfiguration()) assert mock_provider.last_namespaces == [("PIPE",), ()] mock_provider.reset_stats() - resolve.make_configuration(SingleValConfiguration(), namespaces=("ns1",)) + resolve.resolve_configuration(SingleValConfiguration(), namespaces=("ns1",)) # PIPE namespace is exhausted then another lookup without PIPE assert mock_provider.last_namespaces == [("PIPE", "ns1"), ("PIPE",), ("ns1",), ()] mock_provider.return_value_on = ("PIPE", ) mock_provider.reset_stats() - resolve.make_configuration(SingleValConfiguration(), namespaces=("ns1",)) + resolve.resolve_configuration(SingleValConfiguration(), namespaces=("ns1",)) assert mock_provider.last_namespaces == [("PIPE", "ns1"), ("PIPE",)] # with both pipe and config namespaces are always present in lookup @@ -165,15 +161,52 @@ def test_namespace_from_pipeline_name(mock_provider: MockProvider) -> None: mock_provider.reset_stats() # () will never be searched with pytest.raises(ConfigEntryMissingException): - resolve.make_configuration(NamespacedConfiguration()) + resolve.resolve_configuration(NamespacedConfiguration()) mock_provider.return_value_on = ("DLT_TEST",) mock_provider.reset_stats() - resolve.make_configuration(NamespacedConfiguration()) + resolve.resolve_configuration(NamespacedConfiguration()) assert mock_provider.last_namespaces == [("PIPE", "DLT_TEST"), ("DLT_TEST",)] # with pipeline and injected namespaces - with container.injectable_configuration(resolve.ConfigNamespacesConfiguration(pipeline_name="PIPE", namespaces=("inj-ns1",))): + with container.injectable_context(ConfigNamespacesContext(pipeline_name="PIPE", namespaces=("inj-ns1",))): mock_provider.return_value_on = () mock_provider.reset_stats() - resolve.make_configuration(SingleValConfiguration()) + resolve.resolve_configuration(SingleValConfiguration()) + assert mock_provider.last_namespaces == [("PIPE", "inj-ns1"), ("PIPE",), ("inj-ns1",), ()] + + +# def test_namespaces_with_duplicate(mock_provider: MockProvider) -> None: +# container = Container() +# mock_provider.value = "value" + +# with container.injectable_context(ConfigNamespacesContext(pipeline_name="DLT_TEST", namespaces=("DLT_TEST", "DLT_TEST"))): +# mock_provider.return_value_on = ("DLT_TEST",) +# resolve.resolve_configuration(NamespacedConfiguration(), namespaces=("DLT_TEST", "DLT_TEST")) +# # no duplicates are removed, duplicates are misconfiguration +# # note: use dict.fromkeys to create ordered sets from lists if we ever want to remove 
duplicates +# # the lookup tuples are create as follows: +# # 1. (pipeline name, deduplicated namespaces, config namespace) +# # 2. (deduplicated namespaces, config namespace) +# # 3. (pipeline name, config namespace) +# # 4. (config namespace) +# assert mock_provider.last_namespaces == [("DLT_TEST", "DLT_TEST", "DLT_TEST", "DLT_TEST"), ("DLT_TEST", "DLT_TEST", "DLT_TEST"), ("DLT_TEST", "DLT_TEST"), ("DLT_TEST", "DLT_TEST"), ("DLT_TEST",)] + + +def test_inject_namespace(mock_provider: MockProvider) -> None: + mock_provider.value = "value" + + with inject_namespace(ConfigNamespacesContext(pipeline_name="PIPE", namespaces=("inj-ns1",))): + resolve.resolve_configuration(SingleValConfiguration()) assert mock_provider.last_namespaces == [("PIPE", "inj-ns1"), ("PIPE",), ("inj-ns1",), ()] + + # inject with merge previous + with inject_namespace(ConfigNamespacesContext(namespaces=("inj-ns2",))): + mock_provider.reset_stats() + resolve.resolve_configuration(SingleValConfiguration()) + assert mock_provider.last_namespaces == [("PIPE", "inj-ns2"), ("PIPE",), ("inj-ns2",), ()] + + # inject without merge + mock_provider.reset_stats() + with inject_namespace(ConfigNamespacesContext(), merge_existing=False): + resolve.resolve_configuration(SingleValConfiguration()) + assert mock_provider.last_namespaces == [()] diff --git a/tests/common/configuration/utils.py b/tests/common/configuration/utils.py index 430b971491..02d028b1be 100644 --- a/tests/common/configuration/utils.py +++ b/tests/common/configuration/utils.py @@ -2,7 +2,7 @@ from os import environ from typing import Any, List, Optional, Tuple, Type from dlt.common.configuration.container import Container -from dlt.common.configuration.specs.config_providers_configuration import ConfigProvidersListConfiguration +from dlt.common.configuration.specs.config_providers_context import ConfigProvidersListContext from dlt.common.typing import TSecretValue from dlt.common.configuration import configspec @@ -48,7 +48,7 @@ def environment() -> Any: @pytest.fixture(scope="function") def mock_provider() -> "MockProvider": container = Container() - with container.injectable_configuration(ConfigProvidersListConfiguration()) as providers: + with container.injectable_context(ConfigProvidersListContext()) as providers: # replace all providers with MockProvider that does not support secrets mock_provider = MockProvider() providers.providers = [mock_provider] From 71fcddc2ed822c5261e71a16a8aab04c59cafd75 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 18 Oct 2022 13:30:18 +0200 Subject: [PATCH 38/66] implements new instantiation mechanics for destinations --- dlt/load/bigquery/__init__.py | 41 +++++++ dlt/load/bigquery/{client.py => bigquery.py} | 102 +++++++-------- dlt/load/bigquery/configuration.py | 40 +++--- dlt/load/client_base.py | 116 ++++-------------- dlt/load/client_base_impl.py | 81 ++++++++++++ dlt/load/configuration.py | 29 +++-- dlt/load/dummy/__init__.py | 43 +++++++ dlt/load/dummy/configuration.py | 15 +-- dlt/load/dummy/{client.py => dummy.py} | 58 ++++----- dlt/load/exceptions.py | 4 +- dlt/load/load.py | 85 +++++-------- dlt/load/redshift/__init__.py | 41 +++++++ dlt/load/redshift/configuration.py | 20 +-- dlt/load/redshift/{client.py => redshift.py} | 62 ++++------ dlt/load/typing.py | 18 +-- tests/load/bigquery/test_bigquery_client.py | 2 +- .../bigquery/test_bigquery_table_builder.py | 12 +- tests/load/redshift/test_redshift_client.py | 2 +- .../redshift/test_redshift_table_builder.py | 10 +- tests/load/test_client.py | 13 +- 
tests/load/test_dummy_client.py | 83 ++++++------- tests/load/utils.py | 36 ++++-- 22 files changed, 478 insertions(+), 435 deletions(-) rename dlt/load/bigquery/{client.py => bigquery.py} (83%) create mode 100644 dlt/load/client_base_impl.py rename dlt/load/dummy/{client.py => dummy.py} (68%) rename dlt/load/redshift/{client.py => redshift.py} (87%) diff --git a/dlt/load/bigquery/__init__.py b/dlt/load/bigquery/__init__.py index e69de29bb2..e004fe7afd 100644 --- a/dlt/load/bigquery/__init__.py +++ b/dlt/load/bigquery/__init__.py @@ -0,0 +1,41 @@ +from typing import Type + +from dlt.common.schema.schema import Schema +from dlt.common.typing import ConfigValue +from dlt.common.configuration import with_config +from dlt.common.configuration.specs import DestinationCapabilitiesContext + +from dlt.load.client_base import JobClientBase +from dlt.load.configuration import DestinationClientConfiguration +from dlt.load.bigquery.configuration import BigQueryClientConfiguration + + +@with_config(spec=BigQueryClientConfiguration, namespaces=("destination", "bigquery",)) +def _configure(config: BigQueryClientConfiguration = ConfigValue) -> BigQueryClientConfiguration: + return config + + +def capabilities() -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.update({ + "preferred_loader_file_format": "jsonl", + "supported_loader_file_formats": ["jsonl"], + "max_identifier_length": 1024, + "max_column_length": 300, + "max_query_length": 1024 * 1024, + "is_max_query_length_in_bytes": False, + "max_text_data_type_length": 10 * 1024 * 1024, + "is_max_text_data_type_length_in_bytes": True + }) + return caps + + +def client(schema: Schema, initial_config: DestinationClientConfiguration = ConfigValue) -> JobClientBase: + # import client when creating instance so capabilities and config specs can be accessed without dependencies installed + from dlt.load.bigquery.bigquery import BigQueryClient + + return BigQueryClient(schema, _configure(initial_config)) # type: ignore + + +def spec() -> Type[DestinationClientConfiguration]: + return BigQueryClientConfiguration \ No newline at end of file diff --git a/dlt/load/bigquery/client.py b/dlt/load/bigquery/bigquery.py similarity index 83% rename from dlt/load/bigquery/client.py rename to dlt/load/bigquery/bigquery.py index 08a33be155..60996a0591 100644 --- a/dlt/load/bigquery/client.py +++ b/dlt/load/bigquery/bigquery.py @@ -1,6 +1,7 @@ from pathlib import Path from contextlib import contextmanager from typing import Any, AnyStr, Dict, Iterator, List, Optional, Sequence, Tuple +from dlt.common.storages.file_storage import FileStorage import google.cloud.bigquery as bigquery # noqa: I250 from google.cloud.bigquery.dbapi import Connection as DbApiConnection from google.cloud import exceptions as gcp_exceptions @@ -12,15 +13,17 @@ from dlt.common.typing import StrAny from dlt.common.schema.typing import TTableSchema, TWriteDisposition from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.configuration.specs import GcpClientCredentials +from dlt.common.configuration.specs import GcpClientCredentials, DestinationCapabilitiesContext from dlt.common.data_writers import escape_bigquery_identifier from dlt.common.schema import TColumnSchema, TDataType, Schema, TTableSchemaColumns -from dlt.load.typing import LoadJobStatus, DBCursor, TLoaderCapabilities -from dlt.load.client_base import JobClientBase, SqlClientBase, SqlJobClientBase, LoadJob +from dlt.load.typing import TLoadJobStatus, DBCursor 
+from dlt.load.client_base import SqlClientBase, LoadJob +from dlt.load.client_base_impl import SqlJobClientBase from dlt.load.exceptions import LoadClientSchemaWillNotUpdate, LoadJobNotExistsException, LoadJobServerTerminalException, LoadUnknownTableException -from dlt.load.bigquery.configuration import BigQueryClientConfiguration, configuration +from dlt.load.bigquery import capabilities +from dlt.load.bigquery.configuration import BigQueryClientConfiguration SCT_TO_BQT: Dict[TDataType, str] = { @@ -52,21 +55,21 @@ class BigQuerySqlClient(SqlClientBase[bigquery.Client]): - def __init__(self, default_dataset_name: str, CREDENTIALS: GcpClientCredentials) -> None: + def __init__(self, default_dataset_name: str, credentials: GcpClientCredentials) -> None: self._client: bigquery.Client = None - self.C = CREDENTIALS + self.credentials = credentials super().__init__(default_dataset_name) - self.default_retry = bigquery.DEFAULT_RETRY.with_deadline(CREDENTIALS.retry_deadline) + self.default_retry = bigquery.DEFAULT_RETRY.with_deadline(credentials.retry_deadline) self.default_query = bigquery.QueryJobConfig(default_dataset=self.fully_qualified_dataset_name()) def open_connection(self) -> None: # use default credentials if partial config - if not self.C.is_resolved(): + if not self.credentials.is_resolved(): credentials = None else: - credentials = service_account.Credentials.from_service_account_info(self.C.to_native_representation()) - self._client = bigquery.Client(self.C.project_id, credentials=credentials, location=self.C.location) + credentials = service_account.Credentials.from_service_account_info(self.credentials.to_native_representation()) + self._client = bigquery.Client(self.credentials.project_id, credentials=credentials, location=self.credentials.location) def close_connection(self) -> None: if self._client: @@ -79,7 +82,7 @@ def native_connection(self) -> bigquery.Client: def has_dataset(self) -> bool: try: - self._client.get_dataset(self.fully_qualified_dataset_name(), retry=self.default_retry, timeout=self.C.http_timeout) + self._client.get_dataset(self.fully_qualified_dataset_name(), retry=self.default_retry, timeout=self.credentials.http_timeout) return True except gcp_exceptions.NotFound: return False @@ -89,7 +92,7 @@ def create_dataset(self) -> None: self.fully_qualified_dataset_name(), exists_ok=False, retry=self.default_retry, - timeout=self.C.http_timeout + timeout=self.credentials.http_timeout ) def drop_dataset(self) -> None: @@ -98,7 +101,7 @@ def drop_dataset(self) -> None: not_found_ok=True, delete_contents=True, retry=self.default_retry, - timeout=self.C.http_timeout + timeout=self.credentials.http_timeout ) def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: @@ -106,7 +109,7 @@ def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequen def_kwargs = { "job_config": self.default_query, "job_retry": self.default_retry, - "timeout": self.C.http_timeout + "timeout": self.credentials.http_timeout } kwargs = {**def_kwargs, **(kwargs or {})} results = self._client.query(sql, *args, **kwargs).result() @@ -135,19 +138,19 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[D conn.close() def fully_qualified_dataset_name(self) -> str: - return f"{self.C.project_id}.{self.default_dataset_name}" + return f"{self.credentials.project_id}.{self.default_dataset_name}" class BigQueryLoadJob(LoadJob): - def __init__(self, file_name: str, bq_load_job: bigquery.LoadJob, CONFIG: 
GcpClientCredentials) -> None: + def __init__(self, file_name: str, bq_load_job: bigquery.LoadJob, credentials: GcpClientCredentials) -> None: self.bq_load_job = bq_load_job - self.C = CONFIG - self.default_retry = bigquery.DEFAULT_RETRY.with_deadline(CONFIG.retry_deadline) + self.credentials = credentials + self.default_retry = bigquery.DEFAULT_RETRY.with_deadline(credentials.retry_deadline) super().__init__(file_name) - def status(self) -> LoadJobStatus: + def status(self) -> TLoadJobStatus: # check server if done - done = self.bq_load_job.done(retry=self.default_retry, timeout=self.C.http_timeout) + done = self.bq_load_job.done(retry=self.default_retry, timeout=self.credentials.http_timeout) if done: # rows processed if self.bq_load_job.output_rows is not None and self.bq_load_job.error_result is None: @@ -183,15 +186,16 @@ def exception(self) -> str: class BigQueryClient(SqlJobClientBase): - CONFIG: BigQueryClientConfiguration = None - CREDENTIALS: GcpClientCredentials = None + # CONFIG: BigQueryClientConfiguration = None + # CREDENTIALS: GcpClientCredentials = None - def __init__(self, schema: Schema) -> None: + def __init__(self, schema: Schema, config: BigQueryClientConfiguration) -> None: sql_client = BigQuerySqlClient( - schema.normalize_make_dataset_name(self.CONFIG.default_dataset, self.CONFIG.default_schema_name, schema.name), - self.CREDENTIALS + schema.normalize_make_dataset_name(config.dataset_name, config.default_schema_name, schema.name), + config.credentials ) - super().__init__(schema, sql_client) + super().__init__(schema, config, sql_client) + self.config: BigQueryClientConfiguration = config self.sql_client: BigQuerySqlClient = sql_client def initialize_storage(self) -> None: @@ -201,9 +205,9 @@ def initialize_storage(self) -> None: def restore_file_load(self, file_path: str) -> LoadJob: try: return BigQueryLoadJob( - JobClientBase.get_file_name_from_file_path(file_path), + FileStorage.get_file_name_from_file_path(file_path), self._retrieve_load_job(file_path), - self.CREDENTIALS + self.config.credentials #self.sql_client.native_connection() ) except api_core_exceptions.GoogleAPICallError as gace: @@ -218,9 +222,9 @@ def restore_file_load(self, file_path: str) -> LoadJob: def start_file_load(self, table: TTableSchema, file_path: str) -> LoadJob: try: return BigQueryLoadJob( - JobClientBase.get_file_name_from_file_path(file_path), + FileStorage.get_file_name_from_file_path(file_path), self._create_load_job(table["name"], table["write_disposition"], file_path), - self.CREDENTIALS + self.config.credentials ) except api_core_exceptions.GoogleAPICallError as gace: reason = self._get_reason_from_errors(gace) @@ -296,7 +300,9 @@ def _get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns schema_table: TTableSchemaColumns = {} try: table = self.sql_client.native_connection.get_table( - self.sql_client.make_qualified_table_name(table_name), retry=self.sql_client.default_retry, timeout=self.CREDENTIALS.http_timeout + self.sql_client.make_qualified_table_name(table_name), + retry=self.sql_client.default_retry, + timeout=self.config.credentials.http_timeout ) partition_field = table.time_partitioning.field if table.time_partitioning else None for c in table.schema: @@ -329,12 +335,13 @@ def _create_load_job(self, table_name: str, write_disposition: TWriteDisposition ) with open(file_path, "rb") as f: - return self.sql_client.native_connection.load_table_from_file(f, - self.sql_client.make_qualified_table_name(table_name), - job_id=job_id, - 
job_config=job_config, - timeout=self.CREDENTIALS.file_upload_timeout - ) + return self.sql_client.native_connection.load_table_from_file( + f, + self.sql_client.make_qualified_table_name(table_name), + job_id=job_id, + job_config=job_config, + timeout=self.config.credentials.file_upload_timeout + ) def _retrieve_load_job(self, file_path: str) -> bigquery.LoadJob: job_id = BigQueryClient._get_job_id_from_file_path(file_path) @@ -367,22 +374,5 @@ def _bq_t_to_sc_t(bq_t: str, precision: Optional[int], scale: Optional[int]) -> return BQT_TO_SCT.get(bq_t, "text") @classmethod - def capabilities(cls) -> TLoaderCapabilities: - return { - "preferred_loader_file_format": "jsonl", - "supported_loader_file_formats": ["jsonl"], - "max_identifier_length": 1024, - "max_column_length": 300, - "max_query_length": 1024 * 1024, - "is_max_query_length_in_bytes": False, - "max_text_data_type_length": 10 * 1024 * 1024, - "is_max_text_data_type_length_in_bytes": True - } - - @classmethod - def configure(cls, initial_values: StrAny = None) -> Tuple[BigQueryClientConfiguration, GcpClientCredentials]: - cls.CONFIG, cls.CREDENTIALS = configuration(initial_values=initial_values) - return cls.CONFIG, cls.CREDENTIALS - - -CLIENT = BigQueryClient + def capabilities(cls) -> DestinationCapabilitiesContext: + return capabilities() diff --git a/dlt/load/bigquery/configuration.py b/dlt/load/bigquery/configuration.py index cd314b1382..ab66d48e45 100644 --- a/dlt/load/bigquery/configuration.py +++ b/dlt/load/bigquery/configuration.py @@ -1,39 +1,27 @@ -from typing import Tuple +from typing import Optional from google.auth import default as default_credentials from google.auth.exceptions import DefaultCredentialsError -from dlt.common.typing import StrAny -from dlt.common.configuration import make_configuration, configspec +from dlt.common.configuration import configspec from dlt.common.configuration.specs import GcpClientCredentials from dlt.common.configuration.exceptions import ConfigEntryMissingException -from dlt.load.configuration import LoaderClientDwhConfiguration +from dlt.load.configuration import DestinationClientDwhConfiguration -@configspec -class BigQueryClientConfiguration(LoaderClientDwhConfiguration): - client_type: str = "bigquery" +@configspec(init=True) +class BigQueryClientConfiguration(DestinationClientDwhConfiguration): + destination_name: str = "bigquery" + credentials: Optional[GcpClientCredentials] = None - -def configuration(initial_values: StrAny = None) -> Tuple[BigQueryClientConfiguration, GcpClientCredentials]: - - def maybe_partial_credentials() -> GcpClientCredentials: - try: - return make_configuration(GcpClientCredentials(), initial_value=initial_values) - except ConfigEntryMissingException as cfex: - # if config is missing check if credentials can be obtained from defaults + def check_integrity(self) -> None: + if not self.credentials.is_resolved(): + # if config is missing check if credentials can be obtained from defaults try: _, project_id = default_credentials() - # if so then return partial so we can access timeouts - C_PARTIAL = make_configuration(GcpClientCredentials(), initial_value=initial_values, accept_partial = True) # set the project id - it needs to be known by the client - C_PARTIAL.project_id = C_PARTIAL.project_id or project_id - return C_PARTIAL + self.credentials.project_id = self.credentials.project_id or project_id except DefaultCredentialsError: - raise cfex - - return ( - make_configuration(BigQueryClientConfiguration(), initial_value=initial_values), - # allow 
partial credentials so the client can fallback to default credentials - maybe_partial_credentials() - ) + print("DefaultCredentialsError") + # re-raise preventing exception + raise self.credentials.__exception__ diff --git a/dlt/load/client_base.py b/dlt/load/client_base.py index 7ed63c540b..b694128432 100644 --- a/dlt/load/client_base.py +++ b/dlt/load/client_base.py @@ -1,17 +1,16 @@ from abc import ABC, abstractmethod from contextlib import contextmanager from types import TracebackType -from typing import Any, ContextManager, Generic, Iterator, List, Optional, Sequence, Tuple, Type, AnyStr +from typing import Any, ContextManager, Generic, Iterator, Optional, Sequence, Tuple, Type, AnyStr, Protocol from pathlib import Path -from dlt.common import pendulum, logger -from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration -from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns +from dlt.common.schema import Schema from dlt.common.schema.typing import TTableSchema -from dlt.common.typing import StrAny +from dlt.common.typing import ConfigValue +from dlt.common.configuration.specs import DestinationCapabilitiesContext -from dlt.load.typing import LoadJobStatus, TNativeConn, TLoaderCapabilities, DBCursor -from dlt.load.exceptions import LoadClientSchemaVersionCorrupted +from dlt.load.configuration import DestinationClientConfiguration +from dlt.load.typing import TLoadJobStatus, TNativeConn, DBCursor class LoadJob: @@ -32,7 +31,7 @@ def __init__(self, file_name: str) -> None: self._file_name = file_name @abstractmethod - def status(self) -> LoadJobStatus: + def status(self) -> TLoadJobStatus: pass @abstractmethod @@ -44,25 +43,10 @@ def exception(self) -> str: pass -class LoadEmptyJob(LoadJob): - def __init__(self, file_name: str, status: LoadJobStatus, exception: str = None) -> None: - self._status = status - self._exception = exception - super().__init__(file_name) - - def status(self) -> LoadJobStatus: - return self._status - - def file_name(self) -> str: - return self._file_name - - def exception(self) -> str: - return self._exception - - class JobClientBase(ABC): - def __init__(self, schema: Schema) -> None: + def __init__(self, schema: Schema, config: DestinationClientConfiguration) -> None: self.schema = schema + self.config = config @abstractmethod def initialize_storage(self) -> None: @@ -92,27 +76,26 @@ def __enter__(self) -> "JobClientBase": def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None: pass - @staticmethod - def get_file_name_from_file_path(file_path: str) -> str: - return Path(file_path).name - - @staticmethod - def make_job_with_status(file_path: str, status: LoadJobStatus, message: str = None) -> LoadJob: - return LoadEmptyJob(JobClientBase.get_file_name_from_file_path(file_path), status, exception=message) - - @staticmethod - def make_absolute_path(file_path: str) -> str: - return str(Path(file_path).absolute()) - @classmethod @abstractmethod - def capabilities(cls) -> TLoaderCapabilities: + def capabilities(cls) -> DestinationCapabilitiesContext: pass - @classmethod - @abstractmethod - def configure(cls, initial_values: StrAny = None) -> Tuple[BaseConfiguration, CredentialsConfiguration]: - pass + # @classmethod + # @abstractmethod + # def configure(cls, initial_values: StrAny = None) -> Tuple[BaseConfiguration, CredentialsConfiguration]: + # pass + + +class DestinationReference(Protocol): + def capabilities(self) -> DestinationCapabilitiesContext: + ... 
+ + def client(self, schema: Schema, initial_config: DestinationClientConfiguration = ConfigValue) -> "JobClientBase": + ... + + def spec(self) -> Type[DestinationClientConfiguration]: + ... class SqlClientBase(ABC, Generic[TNativeConn]): @@ -176,52 +159,3 @@ def with_alternative_dataset_name(self, dataset_name: str) -> Iterator["SqlClien finally: # restore previous dataset name self.default_dataset_name = current_dataset_name - - -class SqlJobClientBase(JobClientBase): - def __init__(self, schema: Schema, sql_client: SqlClientBase[TNativeConn]) -> None: - super().__init__(schema) - self.sql_client = sql_client - - def update_storage_schema(self) -> None: - storage_version = self._get_schema_version_from_storage() - if storage_version < self.schema.stored_version: - for sql in self._build_schema_update_sql(): - self.sql_client.execute_sql(sql) - self._update_schema_version(self.schema.stored_version) - - def complete_load(self, load_id: str) -> None: - name = self.sql_client.make_qualified_table_name(Schema.LOADS_TABLE_NAME) - now_ts = str(pendulum.now()) - self.sql_client.execute_sql(f"INSERT INTO {name}(load_id, status, inserted_at) VALUES('{load_id}', 0, '{now_ts}');") - - def __enter__(self) -> "SqlJobClientBase": - self.sql_client.open_connection() - return self - - def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None: - self.sql_client.close_connection() - - @abstractmethod - def _build_schema_update_sql(self) -> List[str]: - pass - - def _create_table_update(self, table_name: str, storage_table: TTableSchemaColumns) -> Sequence[TColumnSchema]: - # compare table with stored schema and produce delta - updates = self.schema.get_new_columns(table_name, storage_table) - logger.info(f"Found {len(updates)} updates for {table_name} in {self.schema.name}") - return updates - - def _get_schema_version_from_storage(self) -> int: - name = self.sql_client.make_qualified_table_name(Schema.VERSION_TABLE_NAME) - rows = self.sql_client.execute_sql(f"SELECT {Schema.VERSION_COLUMN_NAME} FROM {name} ORDER BY inserted_at DESC LIMIT 1;") - if len(rows) > 1: - raise LoadClientSchemaVersionCorrupted(self.sql_client.fully_qualified_dataset_name()) - if len(rows) == 0: - return 0 - return int(rows[0][0]) - - def _update_schema_version(self, new_version: int) -> None: - now_ts = str(pendulum.now()) - name = self.sql_client.make_qualified_table_name(Schema.VERSION_TABLE_NAME) - self.sql_client.execute_sql(f"INSERT INTO {name}({Schema.VERSION_COLUMN_NAME}, engine_version, inserted_at) VALUES ({new_version}, {Schema.ENGINE_VERSION}, '{now_ts}');") diff --git a/dlt/load/client_base_impl.py b/dlt/load/client_base_impl.py new file mode 100644 index 0000000000..894e8fb605 --- /dev/null +++ b/dlt/load/client_base_impl.py @@ -0,0 +1,81 @@ +from abc import abstractmethod +from types import TracebackType +from typing import List, Sequence, Type + +from dlt.common import pendulum, logger +from dlt.common.storages import FileStorage +from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns + +from dlt.load.typing import TLoadJobStatus, TNativeConn +from dlt.load.client_base import LoadJob, JobClientBase, SqlClientBase +from dlt.load.configuration import DestinationClientConfiguration +from dlt.load.exceptions import LoadClientSchemaVersionCorrupted + + +class LoadEmptyJob(LoadJob): + def __init__(self, file_name: str, status: TLoadJobStatus, exception: str = None) -> None: + self._status = status + self._exception = exception + super().__init__(file_name) 
+ + @classmethod + def from_file_path(cls, file_path: str, status: TLoadJobStatus, message: str = None) -> "LoadEmptyJob": + return cls(FileStorage.get_file_name_from_file_path(file_path), status, exception=message) + + def status(self) -> TLoadJobStatus: + return self._status + + def file_name(self) -> str: + return self._file_name + + def exception(self) -> str: + return self._exception + + +class SqlJobClientBase(JobClientBase): + def __init__(self, schema: Schema, config: DestinationClientConfiguration, sql_client: SqlClientBase[TNativeConn]) -> None: + super().__init__(schema, config) + self.sql_client = sql_client + + def update_storage_schema(self) -> None: + storage_version = self._get_schema_version_from_storage() + if storage_version < self.schema.stored_version: + for sql in self._build_schema_update_sql(): + self.sql_client.execute_sql(sql) + self._update_schema_version(self.schema.stored_version) + + def complete_load(self, load_id: str) -> None: + name = self.sql_client.make_qualified_table_name(Schema.LOADS_TABLE_NAME) + now_ts = str(pendulum.now()) + self.sql_client.execute_sql(f"INSERT INTO {name}(load_id, status, inserted_at) VALUES('{load_id}', 0, '{now_ts}');") + + def __enter__(self) -> "SqlJobClientBase": + self.sql_client.open_connection() + return self + + def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None: + self.sql_client.close_connection() + + @abstractmethod + def _build_schema_update_sql(self) -> List[str]: + pass + + def _create_table_update(self, table_name: str, storage_table: TTableSchemaColumns) -> Sequence[TColumnSchema]: + # compare table with stored schema and produce delta + updates = self.schema.get_new_columns(table_name, storage_table) + logger.info(f"Found {len(updates)} updates for {table_name} in {self.schema.name}") + return updates + + def _get_schema_version_from_storage(self) -> int: + name = self.sql_client.make_qualified_table_name(Schema.VERSION_TABLE_NAME) + rows = self.sql_client.execute_sql(f"SELECT {Schema.VERSION_COLUMN_NAME} FROM {name} ORDER BY inserted_at DESC LIMIT 1;") + if len(rows) > 1: + raise LoadClientSchemaVersionCorrupted(self.sql_client.fully_qualified_dataset_name()) + if len(rows) == 0: + return 0 + return int(rows[0][0]) + + def _update_schema_version(self, new_version: int) -> None: + now_ts = str(pendulum.now()) + name = self.sql_client.make_qualified_table_name(Schema.VERSION_TABLE_NAME) + self.sql_client.execute_sql(f"INSERT INTO {name}({Schema.VERSION_COLUMN_NAME}, engine_version, inserted_at) VALUES ({new_version}, {Schema.ENGINE_VERSION}, '{now_ts}');") diff --git a/dlt/load/configuration.py b/dlt/load/configuration.py index d7e9ef16bd..bdf79eb604 100644 --- a/dlt/load/configuration.py +++ b/dlt/load/configuration.py @@ -1,25 +1,24 @@ from typing import Optional -from dlt.common.typing import StrAny -from dlt.common.configuration import configspec, make_configuration -from dlt.common.configuration.specs import BaseConfiguration, PoolRunnerConfiguration, LoadVolumeConfiguration, TPoolType +from dlt.common.configuration import configspec +from dlt.common.configuration.specs import BaseConfiguration, PoolRunnerConfiguration, CredentialsConfiguration, TPoolType +from dlt.common.configuration.specs.load_volume_configuration import LoadVolumeConfiguration -@configspec -class LoaderClientConfiguration(BaseConfiguration): - client_type: str = None # which destination to load data to +@configspec(init=True) +class DestinationClientConfiguration(BaseConfiguration): + 
destination_name: str = None # which destination to load data to + credentials: Optional[CredentialsConfiguration] -@configspec -class LoaderClientDwhConfiguration(LoaderClientConfiguration): - default_dataset: str = None # dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix + +@configspec(init=True) +class DestinationClientDwhConfiguration(DestinationClientConfiguration): + dataset_name: str = None # dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix default_schema_name: Optional[str] = None # name of default schema to be used to name effective dataset to load data to -@configspec -class LoaderConfiguration(PoolRunnerConfiguration, LoadVolumeConfiguration, LoaderClientConfiguration): +@configspec(init=True) +class LoaderConfiguration(PoolRunnerConfiguration): workers: int = 20 # how many parallel loads can be executed pool_type: TPoolType = "thread" # mostly i/o (upload) so may be thread pool - - -def configuration(initial_values: StrAny = None) -> LoaderConfiguration: - return make_configuration(LoaderConfiguration(), initial_value=initial_values) + load_storage_config: LoadVolumeConfiguration = None diff --git a/dlt/load/dummy/__init__.py b/dlt/load/dummy/__init__.py index e69de29bb2..2ac57d272a 100644 --- a/dlt/load/dummy/__init__.py +++ b/dlt/load/dummy/__init__.py @@ -0,0 +1,43 @@ +from typing import Type + +from dlt.common.schema.schema import Schema +from dlt.common.typing import ConfigValue +from dlt.common.configuration import with_config +from dlt.common.configuration.specs import DestinationCapabilitiesContext + +from dlt.load.client_base import JobClientBase +from dlt.load.configuration import DestinationClientConfiguration +from dlt.load.dummy.configuration import DummyClientConfiguration + + +@with_config(spec=DummyClientConfiguration, namespaces=("destination", "dummy",)) +def _configure(config: DummyClientConfiguration = ConfigValue) -> DummyClientConfiguration: + print(dict(config)) + return config + + +def capabilities() -> DestinationCapabilitiesContext: + config = _configure() + caps = DestinationCapabilitiesContext() + caps.update({ + "preferred_loader_file_format": config.loader_file_format, + "supported_loader_file_formats": [config.loader_file_format], + "max_identifier_length": 127, + "max_column_length": 127, + "max_query_length": 8 * 1024 * 1024, + "is_max_query_length_in_bytes": True, + "max_text_data_type_length": 65535, + "is_max_text_data_type_length_in_bytes": True + }) + return caps + + +def client(schema: Schema, initial_config: DestinationClientConfiguration = ConfigValue) -> JobClientBase: + # import client when creating instance so capabilities and config specs can be accessed without dependencies installed + from dlt.load.dummy.dummy import DummyClient + + return DummyClient(schema, _configure(initial_config)) # type: ignore + + +def spec() -> Type[DestinationClientConfiguration]: + return DummyClientConfiguration diff --git a/dlt/load/dummy/configuration.py b/dlt/load/dummy/configuration.py index 93d9907258..dc301ecb73 100644 --- a/dlt/load/dummy/configuration.py +++ b/dlt/load/dummy/configuration.py @@ -1,19 +1,14 @@ -from dlt.common.typing import StrAny -from dlt.common.configuration import make_configuration, configspec +from dlt.common.configuration import configspec from dlt.common.data_writers import TLoaderFileFormat -from dlt.load.configuration import LoaderClientConfiguration +from dlt.load.configuration 
import DestinationClientConfiguration -@configspec -class DummyClientConfiguration(LoaderClientConfiguration): - client_type: str = "dummy" +@configspec(init=True) +class DummyClientConfiguration(DestinationClientConfiguration): + destination_name: str = "dummy" loader_file_format: TLoaderFileFormat = "jsonl" fail_prob: float = 0.0 retry_prob: float = 0.0 completed_prob: float = 0.0 timeout: float = 10.0 - - -def configuration(initial_values: StrAny = None) -> DummyClientConfiguration: - return make_configuration(DummyClientConfiguration(), initial_value=initial_values) diff --git a/dlt/load/dummy/client.py b/dlt/load/dummy/dummy.py similarity index 68% rename from dlt/load/dummy/client.py rename to dlt/load/dummy/dummy.py index 3b4888b77d..e3ca263cbb 100644 --- a/dlt/load/dummy/client.py +++ b/dlt/load/dummy/dummy.py @@ -1,19 +1,20 @@ import random from types import TracebackType -from typing import Dict, Tuple, Type +from typing import Dict, Type from dlt.common import pendulum from dlt.common.schema import Schema +from dlt.common.storages import FileStorage from dlt.common.schema.typing import TTableSchema -from dlt.common.configuration.specs import CredentialsConfiguration -from dlt.common.typing import StrAny +from dlt.common.configuration.specs import DestinationCapabilitiesContext -from dlt.load.client_base import JobClientBase, LoadJob, TLoaderCapabilities -from dlt.load.typing import LoadJobStatus +from dlt.load.client_base import JobClientBase, LoadJob +from dlt.load.typing import TLoadJobStatus from dlt.load.exceptions import (LoadJobNotExistsException, LoadJobInvalidStateTransitionException, LoadClientTerminalException, LoadClientTransientException) -from dlt.load.dummy.configuration import DummyClientConfiguration, configuration +from dlt.load.dummy import capabilities +from dlt.load.dummy.configuration import DummyClientConfiguration class LoadDummyJob(LoadJob): @@ -22,7 +23,7 @@ def __init__(self, file_name: str, fail_prob: float = 0.0, retry_prob: float = 0 self.retry_prob = retry_prob self.completed_prob = completed_prob self.timeout = timeout - self._status: LoadJobStatus = "running" + self._status: TLoadJobStatus = "running" self._exception: str = None self.start_time: float = pendulum.now().timestamp() super().__init__(file_name) @@ -33,7 +34,7 @@ def __init__(self, file_name: str, fail_prob: float = 0.0, retry_prob: float = 0 raise LoadClientTransientException(self._exception) - def status(self) -> LoadJobStatus: + def status(self) -> TLoadJobStatus: # this should poll the server for a job status, here we simulate various outcomes if self._status == "running": n = pendulum.now().timestamp() @@ -77,10 +78,10 @@ class DummyClient(JobClientBase): """ dummy client storing jobs in memory """ - CONFIG: DummyClientConfiguration = None - def __init__(self, schema: Schema) -> None: - pass + def __init__(self, schema: Schema, config: DummyClientConfiguration) -> None: + super().__init__(schema, config) + self.config: DummyClientConfiguration = config def initialize_storage(self) -> None: pass @@ -89,8 +90,8 @@ def update_storage_schema(self) -> None: pass def start_file_load(self, table: TTableSchema, file_path: str) -> LoadJob: - job_id = JobClientBase.get_file_name_from_file_path(file_path) - file_name = JobClientBase.get_file_name_from_file_path(file_path) + job_id = FileStorage.get_file_name_from_file_path(file_path) + file_name = FileStorage.get_file_name_from_file_path(file_path) # return existing job if already there if job_id not in JOBS: JOBS[job_id] = 
self._create_job(file_name) @@ -102,7 +103,7 @@ def start_file_load(self, table: TTableSchema, file_path: str) -> LoadJob: return JOBS[job_id] def restore_file_load(self, file_path: str) -> LoadJob: - job_id = JobClientBase.get_file_name_from_file_path(file_path) + job_id = FileStorage.get_file_name_from_file_path(file_path) if job_id not in JOBS: raise LoadJobNotExistsException(job_id) return JOBS[job_id] @@ -119,29 +120,12 @@ def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb def _create_job(self, job_id: str) -> LoadDummyJob: return LoadDummyJob( job_id, - fail_prob=self.CONFIG.fail_prob, - retry_prob=self.CONFIG.retry_prob, - completed_prob=self.CONFIG.completed_prob, - timeout=self.CONFIG.timeout + fail_prob=self.config.fail_prob, + retry_prob=self.config.retry_prob, + completed_prob=self.config.completed_prob, + timeout=self.config.timeout ) @classmethod - def capabilities(cls) -> TLoaderCapabilities: - return { - "preferred_loader_file_format": cls.CONFIG.loader_file_format, - "supported_loader_file_formats": [cls.CONFIG.loader_file_format], - "max_identifier_length": 127, - "max_column_length": 127, - "max_query_length": 8 * 1024 * 1024, - "is_max_query_length_in_bytes": True, - "max_text_data_type_length": 65535, - "is_max_text_data_type_length_in_bytes": True - } - - @classmethod - def configure(cls, initial_values: StrAny = None) -> Tuple[DummyClientConfiguration, CredentialsConfiguration]: - cls.CONFIG = configuration(initial_values=initial_values) - return cls.CONFIG, None - - -CLIENT = DummyClient + def capabilities(cls) -> DestinationCapabilitiesContext: + return capabilities() diff --git a/dlt/load/exceptions.py b/dlt/load/exceptions.py index 62a1cc67cb..7944943b24 100644 --- a/dlt/load/exceptions.py +++ b/dlt/load/exceptions.py @@ -1,7 +1,7 @@ from typing import Sequence from dlt.common.exceptions import DltException, TerminalException, TransientException -from dlt.load.typing import LoadJobStatus +from dlt.load.typing import TLoadJobStatus class LoadException(DltException): @@ -44,7 +44,7 @@ def __init__(self, table_name: str, file_name: str) -> None: class LoadJobInvalidStateTransitionException(LoadClientTerminalException): - def __init__(self, from_state: LoadJobStatus, to_state: LoadJobStatus) -> None: + def __init__(self, from_state: TLoadJobStatus, to_state: TLoadJobStatus) -> None: self.from_state = from_state self.to_state = to_state super().__init__(f"Load job cannot transition form {from_state} to {to_state}") diff --git a/dlt/load/load.py b/dlt/load/load.py index 15a673181a..3980a3203c 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -1,27 +1,23 @@ -from typing import List, Optional, Tuple, Type, Protocol +from typing import List, Optional, Tuple from multiprocessing.pool import ThreadPool -from importlib import import_module from prometheus_client import REGISTRY, Counter, Gauge, CollectorRegistry, Summary from dlt.common import sleep, logger -from dlt.cli import TRunnerArgs -from dlt.common.runners import TRunMetrics, initialize_runner, run_pool, Runnable, workermethod +from dlt.common.configuration import with_config +from dlt.common.typing import ConfigValue +from dlt.common.runners import TRunMetrics, Runnable, workermethod from dlt.common.logger import pretty_format_exception from dlt.common.exceptions import TerminalValueError from dlt.common.schema import Schema from dlt.common.schema.typing import TTableSchema from dlt.common.storages import LoadStorage from dlt.common.telemetry import get_logging_extras, 
set_gauge_all_labels -from dlt.common.typing import StrAny +from dlt.load.client_base import JobClientBase, DestinationReference, LoadJob +from dlt.load.client_base_impl import LoadEmptyJob +from dlt.load.typing import TLoadJobStatus +from dlt.load.configuration import LoaderConfiguration, DestinationClientConfiguration from dlt.load.exceptions import LoadClientTerminalException, LoadClientTransientException, LoadClientUnsupportedWriteDisposition, LoadClientUnsupportedFileFormats, LoadJobNotExistsException, LoadUnknownTableException -from dlt.load.client_base import JobClientBase, LoadJob -from dlt.load.typing import LoadJobStatus, TLoaderCapabilities -from dlt.load.configuration import configuration, LoaderConfiguration - - -class SupportsLoadClient(Protocol): - CLIENT: Type[JobClientBase] class Load(Runnable[ThreadPool]): @@ -31,9 +27,19 @@ class Load(Runnable[ThreadPool]): job_counter: Counter = None job_wait_summary: Summary = None - def __init__(self, config: LoaderConfiguration, collector: CollectorRegistry, client_initial_values: StrAny = None, is_storage_owner: bool = False) -> None: + @with_config(spec=LoaderConfiguration, namespaces=("load",)) + def __init__( + self, + destination: DestinationReference, + collector: CollectorRegistry = REGISTRY, + is_storage_owner: bool = False, + config: LoaderConfiguration = ConfigValue, + initial_client_config: DestinationClientConfiguration = ConfigValue + ) -> None: self.config = config - self.load_client_cls = self.import_client_cls(config.client_type, initial_values=client_initial_values) + self.initial_client_config = initial_client_config + self.destination = destination + self.capabilities = destination.capabilities() self.pool: ThreadPool = None self.load_storage: LoadStorage = self.create_storage(is_storage_owner) try: @@ -43,22 +49,12 @@ def __init__(self, config: LoaderConfiguration, collector: CollectorRegistry, cl if "Duplicated timeseries" not in str(v): raise - @staticmethod - def loader_capabilities(client_type: str) -> TLoaderCapabilities: - m: SupportsLoadClient = import_module(f"dlt.load.{client_type}.client") - return m.CLIENT.capabilities() - - @staticmethod - def import_client_cls(client_type: str, initial_values: StrAny = None) -> Type[JobClientBase]: - m: SupportsLoadClient = import_module(f"dlt.load.{client_type}.client") - m.CLIENT.configure(initial_values) - return m.CLIENT - def create_storage(self, is_storage_owner: bool) -> LoadStorage: load_storage = LoadStorage( is_storage_owner, - self.load_client_cls.capabilities()["preferred_loader_file_format"], - self.load_client_cls.capabilities()["supported_loader_file_formats"] + self.capabilities.preferred_loader_file_format, + self.capabilities.supported_loader_file_formats, + config=self.config.load_storage_config ) return load_storage @@ -86,10 +82,10 @@ def w_spool_job(self: "Load", file_path: str, load_id: str, schema: Schema) -> O # open new connection for each upload job: LoadJob = None try: - with self.load_client_cls(schema) as client: + with self.destination.client(schema, self.initial_client_config) as client: job_info = self.load_storage.parse_job_file_name(file_path) - if job_info.file_format not in client.capabilities()["supported_loader_file_formats"]: - raise LoadClientUnsupportedFileFormats(job_info.file_format, client.capabilities()["supported_loader_file_formats"], file_path) + if job_info.file_format not in self.capabilities.supported_loader_file_formats: + raise LoadClientUnsupportedFileFormats(job_info.file_format, 
self.capabilities.supported_loader_file_formats, file_path) logger.info(f"Will load file {file_path} with table name {job_info.table_name}") table = self.get_load_table(schema, job_info.table_name, file_path) if table["write_disposition"] not in ["append", "replace"]: @@ -98,7 +94,7 @@ def w_spool_job(self: "Load", file_path: str, load_id: str, schema: Schema) -> O except (LoadClientTerminalException, TerminalValueError): # if job irreversibly cannot be started, mark it as failed logger.exception(f"Terminal problem with spooling job {file_path}") - job = JobClientBase.make_job_with_status(file_path, "failed", pretty_format_exception()) + job = LoadEmptyJob.from_file_path(file_path, "failed", pretty_format_exception()) except (LoadClientTransientException, Exception): # return no job so file stays in new jobs (root) folder logger.exception(f"Temporary problem with spooling job {file_path}") @@ -140,7 +136,7 @@ def retrieve_jobs(self, client: JobClientBase, load_id: str) -> Tuple[int, List[ job = client.restore_file_load(file_path) except LoadClientTerminalException: logger.exception(f"Job retrieval for {file_path} failed, job will be terminated") - job = JobClientBase.make_job_with_status(file_path, "failed", pretty_format_exception()) + job = LoadEmptyJob.from_file_path(file_path, "failed", pretty_format_exception()) # proceed to appending job, do not reraise except (LoadClientTransientException, Exception): # raise on all temporary exceptions, typically network / server problems @@ -160,7 +156,7 @@ def complete_jobs(self, load_id: str, jobs: List[LoadJob]) -> List[LoadJob]: for ii in range(len(jobs)): job = jobs[ii] logger.debug(f"Checking status for job {job.file_name()}") - status: LoadJobStatus = job.status() + status: TLoadJobStatus = job.status() final_location: str = None if status == "running": # ask again @@ -206,12 +202,12 @@ def run(self, pool: ThreadPool) -> TRunMetrics: schema = self.load_storage.load_package_schema(load_id) logger.info(f"Loaded schema name {schema.name} and version {schema.stored_version}") # initialize analytical storage ie. 
create dataset required by passed schema - with self.load_client_cls(schema) as client: - logger.info(f"Client {self.config.client_type} will start load") + with self.destination.client(schema, self.initial_client_config) as client: + logger.info(f"Client for {client.config.destination_name} will start load") client.initialize_storage() schema_update = self.load_storage.begin_schema_update(load_id) if schema_update: - logger.info(f"Client {self.config.client_type} will update schema to package schema") + logger.info(f"Client for {client.config.destination_name} will update schema to package schema") # TODO: this should rather generate an SQL job(s) to be executed PRE loading client.update_storage_schema() self.load_storage.commit_schema_update(load_id) @@ -230,7 +226,7 @@ def run(self, pool: ThreadPool) -> TRunMetrics: ) # if there are no existing or new jobs we complete the package if jobs_count == 0: - with self.load_client_cls(schema) as client: + with self.destination.client(schema, self.initial_client_config) as client: # TODO: this script should be executed as a job (and contain also code to merge/upsert data and drop temp tables) # TODO: post loading jobs remaining_jobs = client.complete_load(load_id) @@ -250,18 +246,3 @@ def run(self, pool: ThreadPool) -> TRunMetrics: sleep(1) return TRunMetrics(False, False, len(self.load_storage.list_packages())) - - -def main(args: TRunnerArgs) -> int: - C = configuration(args._asdict()) - initialize_runner(C) - try: - load = Load(C, REGISTRY) - except Exception: - logger.exception("init module") - return -1 - return run_pool(C, load) - - -def run_main(args: TRunnerArgs) -> None: - exit(main(args)) diff --git a/dlt/load/redshift/__init__.py b/dlt/load/redshift/__init__.py index e69de29bb2..0123092700 100644 --- a/dlt/load/redshift/__init__.py +++ b/dlt/load/redshift/__init__.py @@ -0,0 +1,41 @@ +from typing import Type + +from dlt.common.schema.schema import Schema +from dlt.common.typing import ConfigValue +from dlt.common.configuration import with_config +from dlt.common.configuration.specs import DestinationCapabilitiesContext + +from dlt.load.client_base import JobClientBase +from dlt.load.configuration import DestinationClientConfiguration +from dlt.load.redshift.configuration import RedshiftClientConfiguration + + +@with_config(spec=RedshiftClientConfiguration, namespaces=("destination", "redshift",)) +def _configure(config: RedshiftClientConfiguration = ConfigValue) -> RedshiftClientConfiguration: + return config + + +def capabilities() -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.update({ + "preferred_loader_file_format": "insert_values", + "supported_loader_file_formats": ["insert_values"], + "max_identifier_length": 127, + "max_column_length": 127, + "max_query_length": 16 * 1024 * 1024, + "is_max_query_length_in_bytes": True, + "max_text_data_type_length": 65535, + "is_max_text_data_type_length_in_bytes": True + }) + return caps + + +def client(schema: Schema, initial_config: DestinationClientConfiguration = ConfigValue) -> JobClientBase: + # import client when creating instance so capabilities and config specs can be accessed without dependencies installed + from dlt.load.redshift.redshift import RedshiftClient + + return RedshiftClient(schema, _configure(initial_config)) # type: ignore + + +def spec() -> Type[DestinationClientConfiguration]: + return RedshiftClientConfiguration diff --git a/dlt/load/redshift/configuration.py b/dlt/load/redshift/configuration.py index 4f92b1a5bc..fe75be5457 
100644 --- a/dlt/load/redshift/configuration.py +++ b/dlt/load/redshift/configuration.py @@ -1,19 +1,11 @@ -from typing import Tuple - -from dlt.common.typing import StrAny -from dlt.common.configuration import configspec, make_configuration +from dlt.common.configuration import configspec from dlt.common.configuration.specs import PostgresCredentials -from dlt.load.configuration import LoaderClientDwhConfiguration - +from dlt.load.configuration import DestinationClientDwhConfiguration -@configspec -class RedshiftClientConfiguration(LoaderClientDwhConfiguration): - client_type: str = "redshift" +@configspec(init=True) +class RedshiftClientConfiguration(DestinationClientDwhConfiguration): + destination_name: str = "redshift" + credentials: PostgresCredentials -def configuration(initial_values: StrAny = None) -> Tuple[RedshiftClientConfiguration, PostgresCredentials]: - return ( - make_configuration(RedshiftClientConfiguration(), initial_value=initial_values), - make_configuration(PostgresCredentials(), initial_value=initial_values) - ) diff --git a/dlt/load/redshift/client.py b/dlt/load/redshift/redshift.py similarity index 87% rename from dlt/load/redshift/client.py rename to dlt/load/redshift/redshift.py index b4a65b65af..136d06c1a5 100644 --- a/dlt/load/redshift/client.py +++ b/dlt/load/redshift/redshift.py @@ -1,5 +1,4 @@ import platform - if platform.python_implementation() == "PyPy": import psycopg2cffi as psycopg2 from psycopg2cffi.sql import SQL, Identifier, Composed, Literal as SQLLiteral @@ -10,20 +9,21 @@ from contextlib import contextmanager from typing import Any, AnyStr, Dict, Iterator, List, Optional, Sequence, Tuple -from dlt.common.configuration.specs.postgres_credentials import PostgresCredentials -from dlt.common.typing import StrAny from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.configuration.specs import PostgresCredentials, DestinationCapabilitiesContext from dlt.common.data_writers import escape_redshift_identifier from dlt.common.schema import COLUMN_HINTS, TColumnSchema, TColumnSchemaBase, TDataType, THintType, Schema, TTableSchemaColumns, add_missing_hints from dlt.common.schema.typing import TTableSchema, TWriteDisposition +from dlt.common.storages.file_storage import FileStorage -from dlt.load.exceptions import (LoadClientSchemaWillNotUpdate, LoadClientTerminalInnerException, - LoadClientTransientInnerException) -from dlt.load.typing import LoadJobStatus, DBCursor, TLoaderCapabilities -from dlt.load.client_base import JobClientBase, SqlClientBase, SqlJobClientBase, LoadJob +from dlt.load.exceptions import LoadClientSchemaWillNotUpdate, LoadClientTerminalInnerException, LoadClientTransientInnerException +from dlt.load.typing import TLoadJobStatus, DBCursor +from dlt.load.client_base import SqlClientBase, LoadJob +from dlt.load.client_base_impl import SqlJobClientBase, LoadEmptyJob -from dlt.load.redshift.configuration import configuration, RedshiftClientConfiguration +from dlt.load.redshift import capabilities +from dlt.load.redshift.configuration import RedshiftClientConfiguration SCT_TO_PGT: Dict[TDataType, str] = { @@ -57,14 +57,14 @@ class RedshiftSqlClient(SqlClientBase["psycopg2.connection"]): - def __init__(self, default_dataset_name: str, CREDENTIALS: PostgresCredentials) -> None: + def __init__(self, default_dataset_name: str, credentials: PostgresCredentials) -> None: super().__init__(default_dataset_name) self._conn: psycopg2.connection = None - self.C = CREDENTIALS + self.credentials = credentials 
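The sql client now receives resolved PostgresCredentials in its constructor instead of reading a class-level CREDENTIALS attribute. A minimal construction sketch, assuming PG__* settings are present in the environment (as in the table-builder tests) and a reachable database; the dataset name is made up:

    from dlt.common.configuration import resolve_configuration
    from dlt.common.configuration.specs import PostgresCredentials
    from dlt.load.redshift.redshift import RedshiftSqlClient

    # resolve credentials from env vars such as PG__DBNAME and PG__PASSWORD
    credentials = resolve_configuration(PostgresCredentials())
    sql_client = RedshiftSqlClient("test_dataset", credentials)
    sql_client.open_connection()  # unpacks the credentials into psycopg2.connect(**...)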
def open_connection(self) -> None: self._conn = psycopg2.connect( - **self.C, + **self.credentials, options=f"-c search_path={self.fully_qualified_dataset_name()},public" ) # we'll provide explicit transactions @@ -140,12 +140,12 @@ def fully_qualified_dataset_name(self) -> str: class RedshiftInsertLoadJob(LoadJob): def __init__(self, table_name: str, write_disposition: TWriteDisposition, file_path: str, sql_client: SqlClientBase["psycopg2.connection"]) -> None: - super().__init__(JobClientBase.get_file_name_from_file_path(file_path)) + super().__init__(FileStorage.get_file_name_from_file_path(file_path)) self._sql_client = sql_client # insert file content immediately self._insert(sql_client.make_qualified_table_name(table_name), write_disposition, file_path) - def status(self) -> LoadJobStatus: + def status(self) -> TLoadJobStatus: # this job is always done return "completed" @@ -200,15 +200,16 @@ def _insert(self, qualified_table_name: str, write_disposition: TWriteDispositio class RedshiftClient(SqlJobClientBase): - CONFIG: RedshiftClientConfiguration = None - CREDENTIALS: PostgresCredentials = None + # CONFIG: RedshiftClientConfiguration = None + # CREDENTIALS: PostgresCredentials = None - def __init__(self, schema: Schema) -> None: + def __init__(self, schema: Schema, config: RedshiftClientConfiguration) -> None: sql_client = RedshiftSqlClient( - schema.normalize_make_dataset_name(self.CONFIG.default_dataset, self.CONFIG.default_schema_name, schema.name), - self.CREDENTIALS + schema.normalize_make_dataset_name(config.dataset_name, config.default_schema_name, schema.name), + config.credentials ) - super().__init__(schema, sql_client) + super().__init__(schema, config, sql_client) + self.config: RedshiftClientConfiguration = config self.sql_client = sql_client def initialize_storage(self) -> None: @@ -219,7 +220,7 @@ def restore_file_load(self, file_path: str) -> LoadJob: # always returns completed jobs as RedshiftInsertLoadJob is executed # atomically in start_file_load so any jobs that should be recreated are already completed # in case of bugs in loader (asking for jobs that were never created) we are not able to detect that - return JobClientBase.make_job_with_status(file_path, "completed") + return LoadEmptyJob.from_file_path(file_path, "completed") def start_file_load(self, table: TTableSchema, file_path: str) -> LoadJob: try: @@ -334,22 +335,5 @@ def _pq_t_to_sc_t(pq_t: str, precision: Optional[int], scale: Optional[int]) -> return PGT_TO_SCT.get(pq_t, "text") @classmethod - def capabilities(cls) -> TLoaderCapabilities: - return { - "preferred_loader_file_format": "insert_values", - "supported_loader_file_formats": ["insert_values"], - "max_identifier_length": 127, - "max_column_length": 127, - "max_query_length": 16 * 1024 * 1024, - "is_max_query_length_in_bytes": True, - "max_text_data_type_length": 65535, - "is_max_text_data_type_length_in_bytes": True - } - - @classmethod - def configure(cls, initial_values: StrAny = None) -> Tuple[RedshiftClientConfiguration, PostgresCredentials]: - cls.CONFIG, cls.CREDENTIALS = configuration(initial_values=initial_values) - return cls.CONFIG, cls.CREDENTIALS - - -CLIENT = RedshiftClient + def capabilities(cls) -> DestinationCapabilitiesContext: + return capabilities() diff --git a/dlt/load/typing.py b/dlt/load/typing.py index b103cc6719..d6ec5f457a 100644 --- a/dlt/load/typing.py +++ b/dlt/load/typing.py @@ -1,24 +1,10 @@ -from typing import Any, AnyStr, List, Literal, Optional, Tuple, TypeVar, TypedDict +from typing import Any, AnyStr, 
List, Literal, Optional, Tuple, TypeVar -from dlt.common.data_writers import TLoaderFileFormat - - -LoadJobStatus = Literal["running", "failed", "retry", "completed"] +TLoadJobStatus = Literal["running", "failed", "retry", "completed"] # native connection TNativeConn = TypeVar("TNativeConn", bound="object") -class TLoaderCapabilities(TypedDict): - preferred_loader_file_format: TLoaderFileFormat - supported_loader_file_formats: List[TLoaderFileFormat] - max_identifier_length: int - max_column_length: int - max_query_length: int - is_max_query_length_in_bytes: bool - max_text_data_type_length: int - is_max_text_data_type_length_in_bytes: bool - - # type for dbapi cursor class DBCursor: closed: Any diff --git a/tests/load/bigquery/test_bigquery_client.py b/tests/load/bigquery/test_bigquery_client.py index e4ef6dcf70..4e25331bce 100644 --- a/tests/load/bigquery/test_bigquery_client.py +++ b/tests/load/bigquery/test_bigquery_client.py @@ -10,7 +10,7 @@ from dlt.load.exceptions import LoadJobNotExistsException, LoadJobServerTerminalException from dlt.load import Load -from dlt.load.bigquery.client import BigQueryClient +from dlt.load.bigquery.bigquery import BigQueryClient from tests.utils import TEST_STORAGE_ROOT, delete_test_storage from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage, cm_yield_client_with_storage diff --git a/tests/load/bigquery/test_bigquery_table_builder.py b/tests/load/bigquery/test_bigquery_table_builder.py index e3f677e4d4..73a9d31fe7 100644 --- a/tests/load/bigquery/test_bigquery_table_builder.py +++ b/tests/load/bigquery/test_bigquery_table_builder.py @@ -4,10 +4,11 @@ from dlt.common.utils import custom_environ, uniq_id from dlt.common.schema import Schema from dlt.common.schema.utils import new_table -from dlt.common.configuration import make_configuration +from dlt.common.configuration import resolve_configuration from dlt.common.configuration.specs import GcpClientCredentials -from dlt.load.bigquery.client import BigQueryClient +from dlt.load.bigquery.bigquery import BigQueryClient +from dlt.load.bigquery.configuration import BigQueryClientConfiguration from dlt.load.exceptions import LoadClientSchemaWillNotUpdate from tests.load.utils import TABLE_UPDATE @@ -21,19 +22,18 @@ def schema() -> Schema: def test_configuration() -> None: # check names normalized with custom_environ({"GCP__PRIVATE_KEY": "---NO NEWLINE---\n"}): - C = make_configuration(GcpClientCredentials()) + C = resolve_configuration(GcpClientCredentials()) assert C.private_key == "---NO NEWLINE---\n" with custom_environ({"GCP__PRIVATE_KEY": "---WITH NEWLINE---\n"}): - C = make_configuration(GcpClientCredentials()) + C = resolve_configuration(GcpClientCredentials()) assert C.private_key == "---WITH NEWLINE---\n" @pytest.fixture def gcp_client(schema: Schema) -> BigQueryClient: # return client without opening connection - BigQueryClient.configure(initial_values={"default_dataset": uniq_id()}) - return BigQueryClient(schema) + return BigQueryClient(schema, BigQueryClientConfiguration(dataset_name="TEST" + uniq_id(), credentials=GcpClientCredentials())) def test_create_table(gcp_client: BigQueryClient) -> None: diff --git a/tests/load/redshift/test_redshift_client.py b/tests/load/redshift/test_redshift_client.py index a5353b3253..1416f63201 100644 --- a/tests/load/redshift/test_redshift_client.py +++ b/tests/load/redshift/test_redshift_client.py @@ -10,7 +10,7 @@ from dlt.load.exceptions import LoadClientTerminalInnerException from dlt.load import Load -from 
dlt.load.redshift.client import RedshiftClient, RedshiftInsertLoadJob, psycopg2 +from dlt.load.redshift.redshift import RedshiftClient, RedshiftInsertLoadJob, psycopg2 from tests.utils import TEST_STORAGE_ROOT, delete_test_storage, skipifpypy from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage diff --git a/tests/load/redshift/test_redshift_table_builder.py b/tests/load/redshift/test_redshift_table_builder.py index 0c2fcfd77c..e007d8b37a 100644 --- a/tests/load/redshift/test_redshift_table_builder.py +++ b/tests/load/redshift/test_redshift_table_builder.py @@ -4,11 +4,12 @@ from dlt.common.utils import uniq_id, custom_environ from dlt.common.schema import Schema from dlt.common.schema.utils import new_table -from dlt.common.configuration import make_configuration +from dlt.common.configuration import resolve_configuration from dlt.common.configuration.specs import PostgresCredentials from dlt.load.exceptions import LoadClientSchemaWillNotUpdate -from dlt.load.redshift.client import RedshiftClient +from dlt.load.redshift.redshift import RedshiftClient +from dlt.load.redshift.configuration import RedshiftClientConfiguration from tests.load.utils import TABLE_UPDATE @@ -21,14 +22,13 @@ def schema() -> Schema: @pytest.fixture def client(schema: Schema) -> RedshiftClient: # return client without opening connection - RedshiftClient.configure(initial_values={"default_dataset": "TEST" + uniq_id()}) - return RedshiftClient(schema) + return RedshiftClient(schema, RedshiftClientConfiguration(dataset_name="TEST" + uniq_id())) def test_configuration() -> None: # check names normalized with custom_environ({"PG__DBNAME": "UPPER_CASE_DATABASE", "PG__PASSWORD": " pass\n"}): - C = make_configuration(PostgresCredentials()) + C = resolve_configuration(PostgresCredentials()) assert C.dbname == "upper_case_database" assert C.password == "pass" diff --git a/tests/load/test_client.py b/tests/load/test_client.py index d3d66d60eb..89dba8dd45 100644 --- a/tests/load/test_client.py +++ b/tests/load/test_client.py @@ -10,7 +10,8 @@ from dlt.common.schema import TTableSchemaColumns from dlt.common.utils import uniq_id -from dlt.load.client_base import DBCursor, SqlJobClientBase +from dlt.load.client_base import DBCursor +from dlt.load.client_base_impl import SqlJobClientBase from tests.utils import TEST_STORAGE_ROOT, delete_test_storage from tests.common.utils import load_json_case @@ -335,12 +336,12 @@ def test_retrieve_job(client: SqlJobClientBase, file_storage: FileStorage) -> No assert r_job.status() == "completed" -@pytest.mark.parametrize('client_type', ALL_CLIENT_TYPES) -def test_default_schema_name_init_storage(client_type: str) -> None: - with cm_yield_client_with_storage(client_type, initial_values={ - "default_schema_name": "event" # pass the schema that is a default schema. that should create dataset with the name `default_dataset` +@pytest.mark.parametrize('destination_name', ALL_CLIENT_TYPES) +def test_default_schema_name_init_storage(destination_name: str) -> None: + with cm_yield_client_with_storage(destination_name, initial_values={ + "default_schema_name": "event" # pass the schema that is a default schema. 
that should create dataset with the name `dataset_name` }) as client: - assert client.sql_client.default_dataset_name == client.CONFIG.default_dataset + assert client.sql_client.default_dataset_name == client.config.dataset_name def prepare_schema(client: SqlJobClientBase, case: str) -> None: diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index 618c7a2e1b..db8dab7230 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -1,6 +1,5 @@ import shutil import os -from os import environ from multiprocessing.pool import ThreadPool from typing import List, Sequence, Tuple import pytest @@ -13,14 +12,16 @@ from dlt.common.storages.load_storage import JobWithUnsupportedWriterException from dlt.common.typing import StrAny from dlt.common.utils import uniq_id -from dlt.load.client_base import JobClientBase, LoadEmptyJob, LoadJob -from dlt.load.configuration import configuration, LoaderConfiguration -from dlt.load.dummy import client from dlt.load import Load +from dlt.load.client_base import DestinationReference, LoadJob +from dlt.load.client_base_impl import LoadEmptyJob + +from dlt.load import dummy +from dlt.load.dummy import dummy as dummy_impl from dlt.load.dummy.configuration import DummyClientConfiguration -from tests.utils import clean_test_storage, init_logger +from tests.utils import clean_test_storage, init_logger, TEST_DICT_CONFIG_PROVIDER NORMALIZED_FILES = [ @@ -31,7 +32,7 @@ @pytest.fixture(autouse=True) def storage() -> FileStorage: - clean_test_storage(init_normalize=True, init_loader=True) + return clean_test_storage(init_normalize=True, init_loader=True) @pytest.fixture(scope="module", autouse=True) @@ -39,14 +40,6 @@ def logger_autouse() -> None: init_logger() -def test_gen_configuration() -> None: - load = setup_loader() - assert LoaderConfiguration in type(load.config).mro() - # mock missing config values - load = setup_loader(initial_values={"load_volume_path": LoaderConfiguration.load_volume_path}) - assert LoaderConfiguration in type(load.config).mro() - - def test_spool_job_started() -> None: # default config keeps the job always running load = setup_loader() @@ -59,7 +52,7 @@ def test_spool_job_started() -> None: jobs: List[LoadJob] = [] for f in files: job = Load.w_spool_job(load, f, load_id, schema) - assert type(job) is client.LoadDummyJob + assert type(job) is dummy_impl.LoadDummyJob assert job.status() == "running" assert load.load_storage.storage.has_file(load.load_storage._get_job_file_path(load_id, LoadStorage.STARTED_JOBS_FOLDER, job.file_name())) jobs.append(job) @@ -96,7 +89,7 @@ def test_unsupported_write_disposition() -> None: def test_spool_job_failed() -> None: # this config fails job on start - load = setup_loader(initial_client_values={"fail_prob" : 1.0}) + load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0)) load_id, schema = prepare_load_package( load.load_storage, NORMALIZED_FILES @@ -121,7 +114,7 @@ def test_spool_job_failed() -> None: def test_spool_job_retry_new() -> None: # this config retries job on start (transient fail) - load = setup_loader(initial_client_values={"retry_prob" : 1.0}) + load = setup_loader(client_config=DummyClientConfiguration(retry_prob=1.0)) load_id, schema = prepare_load_package( load.load_storage, NORMALIZED_FILES @@ -141,7 +134,7 @@ def test_spool_job_retry_new() -> None: def test_spool_job_retry_started() -> None: # this config keeps the job always running load = setup_loader() - client.CLIENT_CONFIG = DummyClientConfiguration + # 
dummy_impl.CLIENT_CONFIG = DummyClientConfiguration load_id, schema = prepare_load_package( load.load_storage, NORMALIZED_FILES @@ -150,7 +143,7 @@ def test_spool_job_retry_started() -> None: jobs: List[LoadJob] = [] for f in files: job = Load.w_spool_job(load, f, load_id, schema) - assert type(job) is client.LoadDummyJob + assert type(job) is dummy_impl.LoadDummyJob assert job.status() == "running" assert load.load_storage.storage.has_file(load.load_storage._get_job_file_path(load_id, LoadStorage.STARTED_JOBS_FOLDER, job.file_name())) # mock job config to make it retry @@ -162,7 +155,7 @@ def test_spool_job_retry_started() -> None: remaining_jobs = load.complete_jobs(load_id, jobs) assert len(remaining_jobs) == 0 # clear retry flag - client.JOBS = {} + dummy_impl.JOBS = {} files = load.load_storage.list_new_jobs(load_id) assert len(files) == 2 # parse the new job names @@ -183,10 +176,10 @@ def test_try_retrieve_job() -> None: # manually move jobs to started files = load.load_storage.list_new_jobs(load_id) for f in files: - load.load_storage.start_job(load_id, JobClientBase.get_file_name_from_file_path(f)) + load.load_storage.start_job(load_id, FileStorage.get_file_name_from_file_path(f)) # dummy client may retrieve jobs that it created itself, jobs in started folder are unknown # and returned as terminal - with load.load_client_cls(schema) as c: + with load.destination.client(schema, load.initial_client_config) as c: job_count, jobs = load.retrieve_jobs(c, load_id) assert job_count == 2 for j in jobs: @@ -200,7 +193,7 @@ def test_try_retrieve_job() -> None: jobs_count, jobs = load.spool_new_jobs(load_id, schema) assert jobs_count == 2 # now jobs are known - with load.load_client_cls(schema) as c: + with load.destination.client(schema, load.initial_client_config) as c: job_count, jobs = load.retrieve_jobs(c, load_id) assert job_count == 2 for j in jobs: @@ -208,27 +201,27 @@ def test_try_retrieve_job() -> None: def test_completed_loop() -> None: - load = setup_loader(initial_client_values={"completed_prob": 1.0}) + load = setup_loader(client_config=DummyClientConfiguration(completed_prob=1.0)) assert_complete_job(load, load.load_storage.storage) def test_failed_loop() -> None: # ask to delete completed - load = setup_loader(initial_values={"delete_completed_jobs": True}, initial_client_values={"fail_prob": 1.0}) + load = setup_loader(delete_completed_jobs=True, client_config=DummyClientConfiguration(fail_prob=1.0)) # actually not deleted because one of the jobs failed assert_complete_job(load, load.load_storage.storage, should_delete_completed=False) def test_completed_loop_with_delete_completed() -> None: - load = setup_loader(initial_client_values={"completed_prob": 1.0}) - load.config.delete_completed_jobs = True + load = setup_loader(client_config=DummyClientConfiguration(completed_prob=1.0)) load.load_storage = load.create_storage(is_storage_owner=False) + load.load_storage.config.delete_completed_jobs = True assert_complete_job(load, load.load_storage.storage, should_delete_completed=True) def test_retry_on_new_loop() -> None: # test job that retries sitting in new jobs - load = setup_loader(initial_client_values={"retry_prob" : 1.0}) + load = setup_loader(client_config=DummyClientConfiguration(retry_prob=1.0)) load_id, schema = prepare_load_package( load.load_storage, NORMALIZED_FILES @@ -244,7 +237,7 @@ def test_retry_on_new_loop() -> None: files = load.load_storage.list_new_jobs(load_id) assert len(files) == 2 # jobs will be completed - load = 
setup_loader(initial_client_values={"completed_prob" : 1.0}) + load = setup_loader(client_config=DummyClientConfiguration(completed_prob=1.0)) load.run(ThreadPool()) files = load.load_storage.list_new_jobs(load_id) assert len(files) == 0 @@ -284,7 +277,7 @@ def assert_complete_job(load: Load, storage: FileStorage, should_delete_complete NORMALIZED_FILES ) # will complete all jobs - with patch.object(client.DummyClient, "complete_load") as complete_load: + with patch.object(dummy_impl.DummyClient, "complete_load") as complete_load: load.run(ThreadPool()) # did process schema update assert storage.has_file(os.path.join(load.load_storage.get_package_path(load_id), LoadStorage.PROCESSED_SCHEMA_UPDATES_FILE_NAME)) @@ -317,24 +310,18 @@ def prepare_load_package(load_storage: LoadStorage, cases: Sequence[str]) -> Tup return load_id, schema -def setup_loader(initial_values: StrAny = None, initial_client_values: StrAny = None) -> Load: +def setup_loader(delete_completed_jobs: bool = False, client_config: DummyClientConfiguration = None) -> Load: # reset jobs for a test - client.JOBS = {} - - default_values = { - "client_type": "dummy", - "delete_completed_jobs": False - } - default_client_values = { - "loader_file_format": "jsonl" - } - if initial_values: - default_values.update(initial_values) - if initial_client_values: - default_client_values.update(initial_client_values) + dummy_impl.JOBS = {} + destination: DestinationReference = dummy + client_config = client_config or DummyClientConfiguration(loader_file_format="jsonl") + # patch destination to provide client_config + # destination.client = lambda schema: dummy_impl.DummyClient(schema, client_config) + # setup loader - return Load( - configuration(initial_values=default_values), - CollectorRegistry(auto_describe=True), - client_initial_values=default_client_values + with TEST_DICT_CONFIG_PROVIDER().values({"delete_completed_jobs": delete_completed_jobs}): + return Load( + destination, + CollectorRegistry(auto_describe=True), + initial_client_config=client_config ) diff --git a/tests/load/utils.py b/tests/load/utils.py index be067595ed..8cd0ff8ede 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -1,9 +1,10 @@ import contextlib +from importlib import import_module import os from typing import Any, ContextManager, Iterable, Iterator, List, Sequence, cast, IO from dlt.common import json, Decimal -from dlt.common.configuration import make_configuration +from dlt.common.configuration import resolve_configuration from dlt.common.configuration.specs.schema_volume_configuration import SchemaVolumeConfiguration from dlt.common.data_writers import DataWriter from dlt.common.schema import TColumnSchema, TTableSchemaColumns @@ -14,7 +15,9 @@ from dlt.common.utils import uniq_id from dlt.load import Load -from dlt.load.client_base import JobClientBase, LoadJob, SqlJobClientBase +from dlt.load.client_base import DestinationReference, JobClientBase, LoadJob +from dlt.load.client_base_impl import SqlJobClientBase +from dlt.load.configuration import DestinationClientDwhConfiguration TABLE_UPDATE: List[TColumnSchema] = [ { @@ -105,30 +108,43 @@ def prepare_table(client: JobClientBase, case_name: str = "event_user", table_na -def yield_client_with_storage(client_type: str, initial_values: StrAny = None) -> Iterator[SqlJobClientBase]: +def yield_client_with_storage(destination_name: str, initial_values: StrAny = None) -> Iterator[SqlJobClientBase]: os.environ.pop("DEFAULT_DATASET", None) + # import destination reference by name + destination: 
DestinationReference = import_module(f"dlt.load.{destination_name}") # create dataset with random name - default_dataset = "test_" + uniq_id() - client_initial_values = {"default_dataset": default_dataset} + dataset_name = "test_" + uniq_id() + # create initial config + config: DestinationClientDwhConfiguration = None + config = destination.spec()() + # print(config.destination_name) + # print(destination.spec()) + # print(destination.spec().destination_name) + config.dataset_name = dataset_name + if initial_values is not None: - client_initial_values.update(initial_values) + # apply the values to credentials, if dict is provided it will be used as initial + config.credentials = initial_values + # also apply to config + config.update(initial_values) # get event default schema - C = make_configuration(SchemaVolumeConfiguration(), initial_value={ + C = resolve_configuration(SchemaVolumeConfiguration(), initial_value={ "schema_volume_path": "tests/common/cases/schemas/rasa" }) schema_storage = SchemaStorage(C) schema = schema_storage.load_schema("event") # create client and dataset client: SqlJobClientBase = None - with Load.import_client_cls(client_type, initial_values=client_initial_values)(schema) as client: + + with destination.client(schema, config) as client: client.initialize_storage() yield client client.sql_client.drop_dataset() @contextlib.contextmanager -def cm_yield_client_with_storage(client_type: str, initial_values: StrAny = None) -> ContextManager[SqlJobClientBase]: - return yield_client_with_storage(client_type, initial_values) +def cm_yield_client_with_storage(destination_name: str, initial_values: StrAny = None) -> ContextManager[SqlJobClientBase]: + return yield_client_with_storage(destination_name, initial_values) def write_dataset(client: JobClientBase, f: IO[Any], rows: Sequence[StrAny], columns_schema: TTableSchemaColumns) -> None: From 3f40da6d0cc270d05ceae2de7b9d281cb67994e9 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 18 Oct 2022 13:31:10 +0200 Subject: [PATCH 39/66] implements first version of pipeline v2 and applies config injection mechanics everywhere --- dlt/common/data_writers/writers.py | 3 +- dlt/common/normalizers/names/snake_case.py | 4 +- dlt/common/storages/file_storage.py | 21 +- dlt/common/storages/live_schema_storage.py | 19 +- dlt/common/storages/normalize_storage.py | 9 +- dlt/common/storages/schema_storage.py | 50 +- dlt/common/validation.py | 4 +- dlt/dbt_runner/configuration.py | 6 +- dlt/extract/extract.py | 14 +- dlt/extract/pipe.py | 21 +- dlt/extract/sources.py | 4 + dlt/normalize/configuration.py | 11 +- dlt/normalize/normalize.py | 14 +- dlt/pipeline/exceptions.py | 4 +- dlt/pipeline/pipeline.py | 4 +- experiments/pipeline/__init__.py | 21 +- experiments/pipeline/configuration.py | 26 +- experiments/pipeline/credentials.py | 22 - experiments/pipeline/pipeline.py | 539 +++++++++---------- experiments/pipeline/typing.py | 20 +- tests/common/runners/test_runners.py | 6 +- tests/common/schema/test_schema.py | 4 +- tests/common/storages/test_loader_storage.py | 20 +- tests/common/storages/test_schema_storage.py | 12 +- tests/dbt_runner/test_runner_bigquery.py | 2 +- tests/dbt_runner/test_runner_redshift.py | 8 +- tests/normalize/test_normalize.py | 31 +- tests/utils.py | 18 +- 28 files changed, 452 insertions(+), 465 deletions(-) delete mode 100644 experiments/pipeline/credentials.py diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 6656afac92..716b4bdf99 100644 --- 
a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -9,8 +9,7 @@ from dlt.common.json import json_typed_dumps from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.data_writers.escape import escape_redshift_identifier, escape_redshift_literal - -TLoaderFileFormat = Literal["jsonl", "puae-jsonl", "insert_values"] +from dlt.common.configuration.specs.destination_capabilities_context import TLoaderFileFormat @dataclass diff --git a/dlt/common/normalizers/names/snake_case.py b/dlt/common/normalizers/names/snake_case.py index 639d91f089..efeb00c0fb 100644 --- a/dlt/common/normalizers/names/snake_case.py +++ b/dlt/common/normalizers/names/snake_case.py @@ -48,10 +48,10 @@ def normalize_schema_name(name: str) -> str: # build full db dataset (dataset) name out of (normalized) default dataset and schema name -def normalize_make_dataset_name(default_dataset: str, default_schema_name: str, schema_name: str) -> str: +def normalize_make_dataset_name(dataset_name: str, default_schema_name: str, schema_name: str) -> str: if schema_name is None: raise ValueError("schema_name is None") - name = normalize_column_name(default_dataset) + name = normalize_column_name(dataset_name) if default_schema_name is None or schema_name != default_schema_name: name += "_" + schema_name diff --git a/dlt/common/storages/file_storage.py b/dlt/common/storages/file_storage.py index b6cf1d314d..400d162426 100644 --- a/dlt/common/storages/file_storage.py +++ b/dlt/common/storages/file_storage.py @@ -18,10 +18,6 @@ def __init__(self, if makedirs: os.makedirs(storage_path, exist_ok=True) - # @classmethod - # def from_file(cls, file_path: str, file_type: str = "t",) -> "FileStorage": - # return cls(os.path.dirname(file_path), file_type) - def save(self, relative_path: str, data: Any) -> str: return self.save_atomic(self.storage_path, relative_path, data, file_type=self.file_type) @@ -108,19 +104,6 @@ def list_folder_dirs(self, relative_path: str, to_root: bool = True) -> List[str def create_folder(self, relative_path: str, exists_ok: bool = False) -> None: os.makedirs(self.make_full_path(relative_path), exist_ok=exists_ok) - # def copy_cross_storage_atomically(self, dest_volume_root: str, dest_relative_path: str, source_path: str, dest_name: str) -> None: - # external_tmp_file = tempfile.mktemp(dir=dest_volume_root) - # # first copy to temp file - # shutil.copy(self.make_full_path(source_path), external_tmp_file) - # # then rename to dest name - # external_dest = os.path.join(dest_volume_root, dest_relative_path, dest_name) - # try: - # os.rename(external_tmp_file, external_dest) - # except Exception: - # if os.path.isfile(external_tmp_file): - # os.remove(external_tmp_file) - # raise - def link_hard(self, from_relative_path: str, to_relative_path: str) -> None: # note: some interesting stuff on links https://lightrun.com/answers/conan-io-conan-research-investigate-symlinks-and-hard-links os.link( @@ -157,6 +140,10 @@ def make_full_path(self, path: str) -> str: # then assume that it is a path relative to storage root return os.path.join(self.storage_path, path) + @staticmethod + def get_file_name_from_file_path(file_path: str) -> str: + return os.path.basename(file_path) + @staticmethod def validate_file_name_component(name: str) -> None: # Universal platform bans several characters allowed in POSIX ie. 
| < \ or "COM1" :) diff --git a/dlt/common/storages/live_schema_storage.py b/dlt/common/storages/live_schema_storage.py index 3c04ec7439..0a3551b4eb 100644 --- a/dlt/common/storages/live_schema_storage.py +++ b/dlt/common/storages/live_schema_storage.py @@ -1,5 +1,6 @@ -from typing import Any, Dict +from typing import Any, Dict, overload +from dlt.common.typing import ConfigValue from dlt.common.schema.schema import Schema from dlt.common.storages.schema_storage import SchemaStorage from dlt.common.configuration.specs import SchemaVolumeConfiguration @@ -7,9 +8,17 @@ class LiveSchemaStorage(SchemaStorage): - def __init__(self, C: SchemaVolumeConfiguration = None, makedirs: bool = False) -> None: + @overload + def __init__(self, config: SchemaVolumeConfiguration, makedirs: bool = False) -> None: + ... + + @overload + def __init__(self, config: SchemaVolumeConfiguration = ConfigValue, makedirs: bool = False) -> None: + ... + + def __init__(self, config: SchemaVolumeConfiguration = None, makedirs: bool = False) -> None: self.live_schemas: Dict[str, Schema] = {} - super().__init__(C, makedirs) + super().__init__(config, makedirs) def __getitem__(self, name: str) -> Schema: # disconnect live schema @@ -36,9 +45,9 @@ def commit_live_schema(self, name: str) -> Schema: live_schema = self.live_schemas.get(name) if live_schema and live_schema.stored_version_hash != live_schema.version_hash: live_schema.bump_version() - if self.C.import_schema_path: + if self.config.import_schema_path: # overwrite import schemas if specified - self._export_schema(live_schema, self.C.import_schema_path) + self._export_schema(live_schema, self.config.import_schema_path) else: # write directly to schema storage if no import schema folder configured self._save_schema(live_schema) diff --git a/dlt/common/storages/normalize_storage.py b/dlt/common/storages/normalize_storage.py index b668f86aaa..e20652e7b5 100644 --- a/dlt/common/storages/normalize_storage.py +++ b/dlt/common/storages/normalize_storage.py @@ -1,4 +1,4 @@ -from typing import List, Sequence, NamedTuple, overload +from typing import ClassVar, Sequence, NamedTuple, overload from itertools import groupby from pathlib import Path @@ -17,8 +17,8 @@ class TParsedNormalizeFileName(NamedTuple): class NormalizeStorage(VersionedStorage): - STORAGE_VERSION = "1.0.0" - EXTRACTED_FOLDER: str = "extracted" # folder within the volume where extracted files to be normalized are stored + STORAGE_VERSION: ClassVar[str] = "1.0.0" + EXTRACTED_FOLDER: ClassVar[str] = "extracted" # folder within the volume where extracted files to be normalized are stored @overload def __init__(self, is_owner: bool, config: NormalizeVolumeConfiguration) -> None: @@ -32,10 +32,13 @@ def __init__(self, is_owner: bool, config: NormalizeVolumeConfiguration = Config def __init__(self, is_owner: bool, config: NormalizeVolumeConfiguration = ConfigValue) -> None: super().__init__(NormalizeStorage.STORAGE_VERSION, is_owner, FileStorage(config.normalize_volume_path, "t", makedirs=is_owner)) self.config = config + print(is_owner) if is_owner: self.initialize_storage() def initialize_storage(self) -> None: + print(self.storage.storage_path) + print(NormalizeStorage.EXTRACTED_FOLDER) self.storage.create_folder(NormalizeStorage.EXTRACTED_FOLDER, exists_ok=True) def list_files_to_normalize_sorted(self) -> Sequence[str]: diff --git a/dlt/common/storages/schema_storage.py b/dlt/common/storages/schema_storage.py index 305a901c33..309469695b 100644 --- a/dlt/common/storages/schema_storage.py +++ 
b/dlt/common/storages/schema_storage.py @@ -19,17 +19,17 @@ class SchemaStorage(Mapping[str, Schema]): NAMED_SCHEMA_FILE_PATTERN = f"%s_{SCHEMA_FILE_NAME}" @overload - def __init__(self, C: SchemaVolumeConfiguration, makedirs: bool = False) -> None: + def __init__(self, config: SchemaVolumeConfiguration, makedirs: bool = False) -> None: ... @overload - def __init__(self, C: SchemaVolumeConfiguration = ConfigValue, makedirs: bool = False) -> None: + def __init__(self, config: SchemaVolumeConfiguration = ConfigValue, makedirs: bool = False) -> None: ... @with_config(spec=SchemaVolumeConfiguration, namespaces=("schema",)) - def __init__(self, C: SchemaVolumeConfiguration = ConfigValue, makedirs: bool = False) -> None: - self.C = C - self.storage = FileStorage(C.schema_volume_path, makedirs=makedirs) + def __init__(self, config: SchemaVolumeConfiguration = ConfigValue, makedirs: bool = False) -> None: + self.config = config + self.storage = FileStorage(config.schema_volume_path, makedirs=makedirs) def load_schema(self, name: str) -> Schema: # loads a schema from a store holding many schemas @@ -39,21 +39,21 @@ def load_schema(self, name: str) -> Schema: storage_schema = json.loads(self.storage.load(schema_file)) # prevent external modifications of schemas kept in storage if not verify_schema_hash(storage_schema, empty_hash_verifies=True): - raise InStorageSchemaModified(name, self.C.schema_volume_path) + raise InStorageSchemaModified(name, self.config.schema_volume_path) except FileNotFoundError: # maybe we can import from external storage pass # try to import from external storage - if self.C.import_schema_path: + if self.config.import_schema_path: return self._maybe_import_schema(name, storage_schema) if storage_schema is None: - raise SchemaNotFoundError(name, self.C.schema_volume_path) + raise SchemaNotFoundError(name, self.config.schema_volume_path) return Schema.from_dict(storage_schema) def save_schema(self, schema: Schema) -> str: # check if there's schema to import - if self.C.import_schema_path: + if self.config.import_schema_path: try: imported_schema = Schema.from_dict(self._load_import_schema(schema.name)) # link schema being saved to current imported schema so it will not overwrite this save when loaded @@ -62,8 +62,8 @@ def save_schema(self, schema: Schema) -> str: # just save the schema pass path = self._save_schema(schema) - if self.C.export_schema_path: - self._export_schema(schema, self.C.export_schema_path) + if self.config.export_schema_path: + self._export_schema(schema, self.config.export_schema_path) return path def remove_schema(self, name: str) -> None: @@ -120,37 +120,37 @@ def _maybe_import_schema(self, name: str, storage_schema: DictStrAny = None) -> except FileNotFoundError: # no schema to import -> skip silently and return the original if storage_schema is None: - raise SchemaNotFoundError(name, self.C.schema_volume_path, self.C.import_schema_path, self.C.external_schema_format) + raise SchemaNotFoundError(name, self.config.schema_volume_path, self.config.import_schema_path, self.config.external_schema_format) rv_schema = Schema.from_dict(storage_schema) assert rv_schema is not None return rv_schema def _load_import_schema(self, name: str) -> DictStrAny: - import_storage = FileStorage(self.C.import_schema_path, makedirs=False) - schema_file = self._file_name_in_store(name, self.C.external_schema_format) + import_storage = FileStorage(self.config.import_schema_path, makedirs=False) + schema_file = self._file_name_in_store(name, self.config.external_schema_format) 
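
The schema storage classes now expose their configuration as `config` (previously `C`). A minimal usage sketch of the import/export flow in this hunk, assuming the constructor signature introduced here and mirroring the test fixtures later in the patch (paths and the schema name are illustrative, not part of the change):

    from dlt.common.configuration import resolve_configuration
    from dlt.common.configuration.specs import SchemaVolumeConfiguration
    from dlt.common.storages import SchemaStorage
    from dlt.common.schema import Schema

    # resolve a schema volume config with import/export folders (hypothetical paths)
    config = resolve_configuration(SchemaVolumeConfiguration(), initial_value={
        "schema_volume_path": "_storage/schemas",
        "import_schema_path": "_storage/import",
        "export_schema_path": "_storage/export",
        "external_schema_format": "yaml",
    })
    storage = SchemaStorage(config, makedirs=True)
    storage.save_schema(Schema("ethereum"))   # also writes a yaml copy to export_schema_path
    schema = storage.load_schema("ethereum")  # tries import_schema_path if not found in store
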
imported_schema: DictStrAny = None imported_schema_s = import_storage.load(schema_file) - if self.C.external_schema_format == "json": + if self.config.external_schema_format == "json": imported_schema = json.loads(imported_schema_s) - elif self.C.external_schema_format == "yaml": + elif self.config.external_schema_format == "yaml": imported_schema = yaml.safe_load(imported_schema_s) else: - raise ValueError(self.C.external_schema_format) + raise ValueError(self.config.external_schema_format) return imported_schema def _export_schema(self, schema: Schema, export_path: str) -> None: - if self.C.external_schema_format == "json": - exported_schema_s = schema.to_pretty_json(remove_defaults=self.C.external_schema_format_remove_defaults) - elif self.C.external_schema_format == "yaml": - exported_schema_s = schema.to_pretty_yaml(remove_defaults=self.C.external_schema_format_remove_defaults) + if self.config.external_schema_format == "json": + exported_schema_s = schema.to_pretty_json(remove_defaults=self.config.external_schema_format_remove_defaults) + elif self.config.external_schema_format == "yaml": + exported_schema_s = schema.to_pretty_yaml(remove_defaults=self.config.external_schema_format_remove_defaults) else: - raise ValueError(self.C.external_schema_format) + raise ValueError(self.config.external_schema_format) export_storage = FileStorage(export_path, makedirs=True) - schema_file = self._file_name_in_store(schema.name, self.C.external_schema_format) + schema_file = self._file_name_in_store(schema.name, self.config.external_schema_format) export_storage.save(schema_file, exported_schema_s) - logger.info(f"Schema {schema.name} exported to {export_path} with version {schema.stored_version} as {self.C.external_schema_format}") + logger.info(f"Schema {schema.name} exported to {export_path} with version {schema.stored_version} as {self.config.external_schema_format}") def _save_schema(self, schema: Schema) -> str: # save a schema to schema store @@ -162,5 +162,3 @@ def _file_name_in_store(self, name: str, fmt: TSchemaFileFormat) -> str: return SchemaStorage.NAMED_SCHEMA_FILE_PATTERN % (name, fmt) else: return SchemaStorage.SCHEMA_FILE_NAME % fmt - -SchemaStorage(makedirs=True) \ No newline at end of file diff --git a/dlt/common/validation.py b/dlt/common/validation.py index 9fc1349c10..c13e54dd8b 100644 --- a/dlt/common/validation.py +++ b/dlt/common/validation.py @@ -8,13 +8,13 @@ TCustomValidator = Callable[[str, str, Any, Any], bool] -def validate_dict(schema: Type[_TypedDict], doc: StrAny, path: str, filter_f: TFilterFuc = None, validator_f: TCustomValidator = None) -> None: +def validate_dict(spec: Type[_TypedDict], doc: StrAny, path: str, filter_f: TFilterFuc = None, validator_f: TCustomValidator = None) -> None: # pass through filter filter_f = filter_f or (lambda _: True) # cannot validate anything validator_f = validator_f or (lambda p, pk, pv, t: False) - allowed_props = get_type_hints(schema) + allowed_props = get_type_hints(spec) required_props = {k: v for k, v in allowed_props.items() if not is_optional_type(v)} # remove optional props props = {k: v for k, v in doc.items() if filter_f(k)} diff --git a/dlt/dbt_runner/configuration.py b/dlt/dbt_runner/configuration.py index 888c8b7993..2cb17831c1 100644 --- a/dlt/dbt_runner/configuration.py +++ b/dlt/dbt_runner/configuration.py @@ -3,7 +3,7 @@ from typing import List, Optional, Type from dlt.common.typing import StrAny, TSecretValue -from dlt.common.configuration import make_configuration, configspec +from dlt.common.configuration 
import resolve_configuration, configspec from dlt.common.configuration.providers import EnvironProvider from dlt.common.configuration.specs import PoolRunnerConfiguration, TPoolType, PostgresCredentials, GcpClientCredentials @@ -40,7 +40,7 @@ def gen_configuration_variant(initial_values: StrAny = None) -> DBTRunnerConfigu DBTRunnerConfigurationImpl: Type[DBTRunnerConfiguration] environ = EnvironProvider() - source_schema_prefix: str = environ.get_value("default_dataset", type(str)) # type: ignore + source_schema_prefix: str = environ.get_value("dataset_name", type(str)) # type: ignore if environ.get_value("project_id", type(str), GcpClientCredentials.__namespace__): @configspec @@ -54,4 +54,4 @@ class DBTRunnerConfigurationGcp(GcpClientCredentials, DBTRunnerConfiguration): SOURCE_SCHEMA_PREFIX: str = source_schema_prefix DBTRunnerConfigurationImpl = DBTRunnerConfigurationGcp - return make_configuration(DBTRunnerConfigurationImpl(), initial_value=initial_values) + return resolve_configuration(DBTRunnerConfigurationImpl(), initial_value=initial_values) diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index 0b9d96578e..c048880948 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -1,5 +1,5 @@ import os -from typing import List +from typing import ClassVar, List from dlt.common.utils import uniq_id from dlt.common.sources import TDirectDataItem, TDataItem @@ -13,14 +13,11 @@ class ExtractorStorage(DataItemStorage, NormalizeStorage): - EXTRACT_FOLDER = "extract" + EXTRACT_FOLDER: ClassVar[str] = "extract" def __init__(self, C: NormalizeVolumeConfiguration) -> None: # data item storage with jsonl with pua encoding - super().__init__("puae-jsonl", False, C) - self.initialize_storage() - - def initialize_storage(self) -> None: + super().__init__("puae-jsonl", True, C) self.storage.create_folder(ExtractorStorage.EXTRACT_FOLDER, exists_ok=True) def create_extract_id(self) -> str: @@ -49,7 +46,8 @@ def _get_extract_path(self, extract_id: str) -> str: return os.path.join(ExtractorStorage.EXTRACT_FOLDER, extract_id) -def extract(source: DltSource, storage: ExtractorStorage) -> TSchemaUpdate: +def extract(source: DltSource, storage: ExtractorStorage, *, max_parallel_items: int = 100, workers: int = 5, futures_poll_interval: float = 0.01) -> TSchemaUpdate: + # TODO: add metrics: number of items processed, also per resource and table dynamic_tables: TSchemaUpdate = {} schema = source.schema extract_id = storage.create_extract_id() @@ -80,7 +78,7 @@ def _write_dynamic_table(resource: DltResource, item: TDataItem) -> None: _write_item(table_name, item) # yield from all selected pipes - for pipe_item in PipeIterator.from_pipes(source.pipes): + for pipe_item in PipeIterator.from_pipes(source.pipes, max_parallel_items=max_parallel_items, workers=workers, futures_poll_interval=futures_poll_interval): # get partial table from table template resource = source.resource_by_pipe(pipe_item.pipe) if resource._table_name_hint_fun: diff --git a/dlt/extract/pipe.py b/dlt/extract/pipe.py index 18142b6dbe..bc22e77ec6 100644 --- a/dlt/extract/pipe.py +++ b/dlt/extract/pipe.py @@ -193,15 +193,15 @@ def __repr__(self) -> str: class PipeIterator(Iterator[PipeItem]): @configspec - class PipeIteratorConfiguration: + class PipeIteratorConfiguration(BaseConfiguration): max_parallel_items: int = 100 - worker_threads: int = 5 + workers: int = 5 futures_poll_interval: float = 0.01 - def __init__(self, max_parallel_items: int, worker_threads, futures_poll_interval: float) -> None: + def __init__(self, 
max_parallel_items: int, workers: int, futures_poll_interval: float) -> None: self.max_parallel_items = max_parallel_items - self.worker_threads = worker_threads + self.workers = workers self.futures_poll_interval = futures_poll_interval self._async_pool: asyncio.AbstractEventLoop = None @@ -212,21 +212,21 @@ def __init__(self, max_parallel_items: int, worker_threads, futures_poll_interva @classmethod @with_config(spec=PipeIteratorConfiguration) - def from_pipe(cls, pipe: Pipe, *, max_parallelism: int = 100, worker_threads: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": + def from_pipe(cls, pipe: Pipe, *, max_parallel_items: int = 100, workers: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": if pipe.parent: pipe = pipe.full_pipe() # head must be iterator assert isinstance(pipe.head, Iterator) # create extractor - extract = cls(max_parallelism, worker_threads, futures_poll_interval) + extract = cls(max_parallel_items, workers, futures_poll_interval) # add as first source extract._sources.append(SourcePipeItem(pipe.head, 0, pipe)) return extract @classmethod @with_config(spec=PipeIteratorConfiguration) - def from_pipes(cls, pipes: Sequence[Pipe], yield_parents: bool = True, *, max_parallelism: int = 100, worker_threads: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": - extract = cls(max_parallelism, worker_threads, futures_poll_interval) + def from_pipes(cls, pipes: Sequence[Pipe], yield_parents: bool = True, *, max_parallel_items: int = 100, workers: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": + extract = cls(max_parallel_items, workers, futures_poll_interval) # clone all pipes before iterating (recursively) as we will fork them and this add steps pipes = PipeIterator.clone_pipes(pipes) @@ -247,7 +247,6 @@ def _fork_pipeline(pipe: Pipe) -> None: print("add to sources: " + pipe.name) extract._sources.append(SourcePipeItem(pipe.head, 0, pipe)) - for pipe in reversed(pipes): _fork_pipeline(pipe) @@ -315,7 +314,7 @@ def __next__(self) -> PipeItem: step = pipe_item.pipe[pipe_item.step + 1] assert callable(step) item = step(pipe_item.item) - pipe_item = ResolvablePipeItem(item, pipe_item.step + 1, pipe_item.pipe) # type: ignore + pipe_item = ResolvablePipeItem(item, pipe_item.step + 1, pipe_item.pipe) def _ensure_async_pool(self) -> asyncio.AbstractEventLoop: @@ -339,7 +338,7 @@ def _ensure_thread_pool(self) -> ThreadPoolExecutor: if self._thread_pool: return self._thread_pool - self._thread_pool = ThreadPoolExecutor(self.worker_threads) + self._thread_pool = ThreadPoolExecutor(self.workers) return self._thread_pool def __enter__(self) -> "PipeIterator": diff --git a/dlt/extract/sources.py b/dlt/extract/sources.py index c4703b1d09..8f8eef272e 100644 --- a/dlt/extract/sources.py +++ b/dlt/extract/sources.py @@ -58,6 +58,10 @@ def _resolve_hint(hint: Union[Any, TFunDataItemDynHint]) -> Any: return cast(TPartialTableSchema, self._table_schema_template) def _set_template(self, table_schema_template: TTableSchemaTemplate) -> None: + # validate template + # TODO: name must be set if any other properties are set + # TODO: remove all none values + # if "name" is callable in the template then the table schema requires actual data item to be inferred name_hint = table_schema_template.get("name") if callable(name_hint): diff --git a/dlt/normalize/configuration.py b/dlt/normalize/configuration.py index 1a924520d6..1dd4b3b659 100644 --- a/dlt/normalize/configuration.py +++ b/dlt/normalize/configuration.py @@ -1,10 +1,11 @@ -from 
dlt.common.typing import StrAny -from dlt.common.data_writers import TLoaderFileFormat -from dlt.common.configuration import make_configuration, configspec -from dlt.common.configuration.specs import PoolRunnerConfiguration, TPoolType +from dlt.common.configuration import configspec +from dlt.common.configuration.specs import LoadVolumeConfiguration, NormalizeVolumeConfiguration, SchemaVolumeConfiguration, PoolRunnerConfiguration, DestinationCapabilitiesContext, TPoolType @configspec(init=True) class NormalizeConfiguration(PoolRunnerConfiguration): - loader_file_format: TLoaderFileFormat = "jsonl" # jsonp or insert commands will be generated pool_type: TPoolType = "process" + destination_capabilities: DestinationCapabilitiesContext = None # injectable + schema_storage_config: SchemaVolumeConfiguration + normalize_storage_config: NormalizeVolumeConfiguration + load_storage_config: LoadVolumeConfiguration diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 0d85b97e85..3ff8f8d6ef 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -36,8 +36,9 @@ class Normalize(Runnable[ProcessPool]): load_package_counter: Counter = None @with_config(spec=NormalizeConfiguration, namespaces=("normalize",)) - def __init__(self, config: NormalizeConfiguration = ConfigValue, collector: CollectorRegistry = REGISTRY, schema_storage: SchemaStorage = None) -> None: + def __init__(self, collector: CollectorRegistry = REGISTRY, schema_storage: SchemaStorage = None, config: NormalizeConfiguration = ConfigValue) -> None: self.config = config + self.loader_file_format = config.destination_capabilities.preferred_loader_file_format self.pool: ProcessPool = None self.normalize_storage: NormalizeStorage = None self.load_storage: LoadStorage = None @@ -46,7 +47,7 @@ def __init__(self, config: NormalizeConfiguration = ConfigValue, collector: Coll # setup storages self.create_storages() # create schema storage with give type - self.schema_storage = schema_storage or SchemaStorage(makedirs=True) + self.schema_storage = schema_storage or SchemaStorage(self.config.schema_storage_config, makedirs=True) try: self.create_gauges(collector) except ValueError as v: @@ -62,9 +63,10 @@ def create_gauges(registry: CollectorRegistry) -> None: Normalize.load_package_counter = Gauge("normalize_load_packages_created_count", "Count of load package created", ["schema"], registry=registry) def create_storages(self) -> None: - self.normalize_storage = NormalizeStorage(True) + # pass initial normalize storage config embedded in normalize config + self.normalize_storage = NormalizeStorage(True, config=self.config.normalize_storage_config) # normalize saves in preferred format but can read all supported formats - self.load_storage = LoadStorage(True, self.config.loader_file_format, LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + self.load_storage = LoadStorage(True, self.loader_file_format, LoadStorage.ALL_SUPPORTED_FILE_FORMATS, config=self.config.load_storage_config) @staticmethod @@ -159,7 +161,7 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TM # TODO: maybe we should chunk by file size, now map all files to workers chunk_files = [files] schema_dict = schema.to_dict() - config_tuple = (self.normalize_storage.config, self.load_storage.config, self.config.loader_file_format, schema_dict) + config_tuple = (self.normalize_storage.config, self.load_storage.config, self.loader_file_format, schema_dict) param_chunk = [(*config_tuple, load_id, files) for files in chunk_files] 
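
With this change the normalizer no longer carries `loader_file_format` in its own config; it reads `preferred_loader_file_format` from the injected destination capabilities. A minimal sketch of that wiring, following the same injection pattern the tests in this patch use:

    from prometheus_client import CollectorRegistry
    from dlt.common.configuration.container import Container
    from dlt.common.configuration.specs import DestinationCapabilitiesContext
    from dlt.normalize import Normalize

    # make the capabilities available as an injectable context before constructing Normalize
    with Container().injectable_context(DestinationCapabilitiesContext(preferred_loader_file_format="jsonl")):
        n = Normalize(collector=CollectorRegistry())

    # the normalizer and its load storage both pick up the preferred format
    assert n.load_storage.loader_file_format == n.loader_file_format == "jsonl"
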
processed_chunks = self.pool.starmap(Normalize.w_normalize_files, param_chunk) return sum([t[1] for t in processed_chunks]), [t[0] for t in processed_chunks], chunk_files @@ -168,7 +170,7 @@ def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TMap processed_chunk = Normalize.w_normalize_files( self.normalize_storage.config, self.load_storage.config, - self.config.loader_file_format, + self.loader_file_format, schema.to_dict(), load_id, files diff --git a/dlt/pipeline/exceptions.py b/dlt/pipeline/exceptions.py index a250b31047..88ffe4c3c4 100644 --- a/dlt/pipeline/exceptions.py +++ b/dlt/pipeline/exceptions.py @@ -50,8 +50,8 @@ def __init__(self, method: str) -> None: class SqlClientNotAvailable(PipelineException): - def __init__(self, client_type: str) -> None: - super().__init__(f"SQL Client not available in {client_type}") + def __init__(self, destination_name: str) -> None: + super().__init__(f"SQL Client not available for {destination_name}") class InvalidIteratorException(PipelineException): diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index b365aded5a..e055717312 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -11,7 +11,7 @@ from dlt.common import json, sleep, signals, logger from dlt.common.runners import pool_runner as runner, TRunMetrics, initialize_runner -from dlt.common.configuration import make_configuration +from dlt.common.configuration import resolve_configuration from dlt.common.configuration.specs import PoolRunnerConfiguration from dlt.common.storages import FileStorage from dlt.common.schema import Schema @@ -46,7 +46,7 @@ def __init__(self, pipeline_name: str, log_level: str = "INFO") -> None: self._loader_instance: Load = None # patch config and initialize pipeline - self.C = make_configuration(PoolRunnerConfiguration(), initial_value={ + self.C = resolve_configuration(PoolRunnerConfiguration(), initial_value={ "PIPELINE_NAME": pipeline_name, "LOG_LEVEL": log_level, "POOL_TYPE": "None", diff --git a/experiments/pipeline/__init__.py b/experiments/pipeline/__init__.py index 9db28049a4..a334c8c90b 100644 --- a/experiments/pipeline/__init__.py +++ b/experiments/pipeline/__init__.py @@ -1,15 +1,10 @@ -# from experiments.pipeline.pipeline import Pipeline - -# pipeline = Pipeline() - -# def __getattr__(name): -# if name == 'y': -# return 3 -# raise AttributeError(f"module '{__name__}' has no attribute '{name}'") import tempfile +from typing import Union +from importlib import import_module from dlt.common.typing import TSecretValue, Any from dlt.common.configuration import with_config +from dlt.load.client_base import DestinationReference from experiments.pipeline.configuration import PipelineConfiguration from experiments.pipeline.pipeline import Pipeline @@ -30,11 +25,17 @@ @with_config(spec=PipelineConfiguration, auto_namespace=True) -def configure(pipeline_name: str = None, working_dir: str = None, pipeline_secret: TSecretValue = None, **kwargs: Any) -> Pipeline: +def configure(pipeline_name: str = None, working_dir: str = None, pipeline_secret: TSecretValue = None, destination: Union[None, str, DestinationReference] = None, **kwargs: Any) -> Pipeline: + print(locals()) + print(kwargs["_last_dlt_config"].pipeline_name) # if working_dir not provided use temp folder if not working_dir: working_dir = tempfile.gettempdir() - return Pipeline(pipeline_name, working_dir, pipeline_secret, kwargs["runtime"]) + # if destination is a str, get destination reference by dynamically importing module from known location 
+ if isinstance(destination, str): + destination = import_module(f"dlt.load.{destination}") + + return Pipeline(pipeline_name, working_dir, pipeline_secret, destination, kwargs["runtime"]) def run() -> Pipeline: return configure().extract() \ No newline at end of file diff --git a/experiments/pipeline/configuration.py b/experiments/pipeline/configuration.py index 02109abf50..2f374524ff 100644 --- a/experiments/pipeline/configuration.py +++ b/experiments/pipeline/configuration.py @@ -1,18 +1,34 @@ -from typing import Optional +from typing import ClassVar, Optional, TYPE_CHECKING +from typing_extensions import runtime from dlt.common.configuration import configspec -from dlt.common.configuration.specs import RunConfiguration, BaseConfiguration - +from dlt.common.configuration.specs import RunConfiguration, BaseConfiguration, ContainerInjectableContext from dlt.common.typing import TSecretValue from dlt.common.utils import uniq_id +from experiments.pipeline.typing import TPipelineState + @configspec class PipelineConfiguration(BaseConfiguration): + pipeline_name: Optional[str] = None working_dir: Optional[str] = None pipeline_secret: Optional[TSecretValue] = None runtime: RunConfiguration def check_integrity(self) -> None: - if self.pipeline_secret: - self.pipeline_secret = uniq_id() + if not self.pipeline_secret: + self.pipeline_secret = TSecretValue(uniq_id()) + if not self.pipeline_name: + self.pipeline_name = self.runtime.pipeline_name + + +@configspec(init=True) +class StateInjectableContext(ContainerInjectableContext): + state: TPipelineState + + can_create_default: ClassVar[bool] = False + + if TYPE_CHECKING: + def __init__(self, state: TPipelineState = None) -> None: + ... diff --git a/experiments/pipeline/credentials.py b/experiments/pipeline/credentials.py deleted file mode 100644 index 1513e06f42..0000000000 --- a/experiments/pipeline/credentials.py +++ /dev/null @@ -1,22 +0,0 @@ - -from typing import Any, Sequence, Type - -# gets credentials in namespace (ie pipeline name), grouped under key with spec -# spec can be a class, TypedDict or dataclass. 
overwrites initial_values -def get_credentials(spec: Type[Any] = None, key: str = None, namespace: str = None, initial_values: Any = None) -> Any: - # will use registered credential providers for all values in spec or return all values under key - pass - - -def get_config(spec: Type[Any], key: str = None, namespace: str = None, initial_values: Any = None) -> Any: - # uses config providers (env, .dlt/config.toml) - # in case of TSecretValues fallbacks to using credential providers - pass - - -class ConfigProvider: - def get(name: str) -> Any: - pass - - def list(prefix: str = None) -> Sequence[str]: - pass diff --git a/experiments/pipeline/pipeline.py b/experiments/pipeline/pipeline.py index faa642be06..e97a8666ec 100644 --- a/experiments/pipeline/pipeline.py +++ b/experiments/pipeline/pipeline.py @@ -1,81 +1,75 @@ import os -from collections import abc -import tempfile from contextlib import contextmanager from copy import deepcopy from functools import wraps -from typing import Any, List, Iterable, Iterator, Mapping, NewType, Optional, Sequence, Type, TypedDict, Union, overload -from operator import itemgetter -from prometheus_client import REGISTRY +from typing import Any, Callable, ClassVar, List, Iterable, Iterator, Generator, Mapping, NewType, Optional, Sequence, Tuple, Type, TypedDict, Union, get_type_hints, overload from dlt.common import json, logger, signals -from dlt.common.sources import DLT_METADATA_FIELD, with_table_name +from dlt.common.configuration.container import Container +from dlt.common.configuration.specs.config_namespace_context import ConfigNamespacesContext +from dlt.common.runners.runnable import Runnable +from dlt.common.sources import DLT_METADATA_FIELD, TResolvableDataItem, with_table_name from dlt.common.typing import DictStrAny, StrAny, TFun, TSecretValue, TAny from dlt.common.runners import pool_runner as runner, TRunMetrics, initialize_runner -from dlt.common.schema.utils import normalize_schema_name from dlt.common.storages import LiveSchemaStorage, NormalizeStorage -from dlt.common.configuration import make_configuration, RunConfiguration, NormalizeVolumeConfiguration, SchemaVolumeConfiguration, ProductionNormalizeVolumeConfiguration +from dlt.common.configuration import inject_namespace +from dlt.common.configuration.specs import RunConfiguration, NormalizeVolumeConfiguration, SchemaVolumeConfiguration, LoadVolumeConfiguration, PoolRunnerConfiguration, DestinationCapabilitiesContext from dlt.common.schema.schema import Schema from dlt.common.storages.file_storage import FileStorage -from dlt.common.utils import is_interactive, uniq_id +from dlt.common.utils import is_interactive +from dlt.extract.extract import ExtractorStorage, extract -from dlt.extract.extractor_storage import ExtractorStorageBase -from dlt.load.typing import TLoaderCapabilities -from dlt.normalize.configuration import configuration as normalize_configuration from dlt.normalize import Normalize -from dlt.load.client_base import SqlClientBase, SqlJobClientBase -from dlt.load.configuration import LoaderClientDwhConfiguration, configuration as loader_configuration +from dlt.load.client_base import DestinationReference, JobClientBase, SqlClientBase +from dlt.load.configuration import DestinationClientConfiguration, DestinationClientDwhConfiguration, LoaderConfiguration from dlt.load import Load +from dlt.normalize.configuration import NormalizeConfiguration -from experiments.pipeline.configuration import get_config -from experiments.pipeline.exceptions import PipelineConfigMissing, 
PipelineConfiguredException, MissingDependencyException, PipelineStepFailed -from dlt.extract.sources import DltSource, TResolvableDataItem +from experiments.pipeline.exceptions import PipelineConfigMissing, MissingDependencyException, PipelineStepFailed +from dlt.extract.sources import DltResource, DltSource, TTableSchemaTemplate +from experiments.pipeline.typing import TPipelineStep, TPipelineState +from experiments.pipeline.configuration import StateInjectableContext -TConnectionString = NewType("TConnectionString", str) -TSourceState = NewType("TSourceState", DictStrAny) +class Pipeline: -TCredentials = Union[TConnectionString, StrAny] + STATE_FILE: ClassVar[str] = "state.json" + STATE_PROPS: ClassVar[List[str]] = list(get_type_hints(TPipelineState).keys()) -class TPipelineState(TypedDict): pipeline_name: str - default_dataset: str - # is_transient: bool - default_schema_name: Optional[str] - # pipeline_secret: TSecretValue - destination_name: Optional[str] - # schema_sync_path: Optional[str] - - -# class TPipelineState() -# sources: Dict[str, TSourceState] - - -class Pipeline: + dataset_name: str + default_schema_name: str + working_dir: str - STATE_FILE = "state.json" - - def __init__(self, pipeline_name: str, working_dir: str, pipeline_secret: TSecretValue, runtime: RunConfiguration): - self.pipeline_name = pipeline_name - self.working_dir = working_dir + def __init__(self, pipeline_name: str, working_dir: str, pipeline_secret: TSecretValue, destination: DestinationReference, runtime: RunConfiguration): self.pipeline_secret = pipeline_secret self.runtime_config = runtime + self.destination = destination self.root_folder: str = None - # self._initial_values: DictStrAny = {} - self._state: TPipelineState = {} + self._container = Container() + self._state: TPipelineState = {} # type: ignore self._pipeline_storage: FileStorage = None - self._extractor_storage: ExtractorStorageBase = None self._schema_storage: LiveSchemaStorage = None + # self._pool_config: PoolRunnerConfiguration = None + self._schema_storage_config: SchemaVolumeConfiguration = None + self._normalize_storage_config: NormalizeVolumeConfiguration = None + self._load_storage_config: LoadVolumeConfiguration = None + + initialize_runner(self.runtime_config) + self._configure(pipeline_name, working_dir) def with_state_sync(f: TFun) -> TFun: @wraps(f) def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: - with self._managed_state(): - return f(self, *args, **kwargs) + # backup and restore state + with self._managed_state() as state: + # add the state to container as a context + with self._container.injectable_context(StateInjectableContext(state=state)): + return f(self, *args, **kwargs) return _wrap @@ -90,11 +84,33 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: return _wrap + def with_config_namespace(namespaces: Tuple[str, ...]) -> Callable[[TFun], TFun]: + + def decorator(f: TFun) -> TFun: + + @wraps(f) + def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: + # add namespace context to the container to be used by all configuration without explicit namespaces resolution + with inject_namespace(ConfigNamespacesContext(pipeline_name=self.pipeline_name, namespaces=namespaces)): + return f(self, *args, **kwargs) + + return _wrap + + return decorator + @with_state_sync - def _configure(self) -> None: + def _configure(self, pipeline_name: str, working_dir: str) -> None: + self.pipeline_name = pipeline_name + self.working_dir = working_dir + # compute the folder that keeps all of the pipeline 
state FileStorage.validate_file_name_component(self.pipeline_name) self.root_folder = os.path.join(self.working_dir, self.pipeline_name) + # create default configs + # self._pool_config = PoolRunnerConfiguration(is_single_run=True, exit_on_exception=True) + self._schema_storage_config = SchemaVolumeConfiguration(schema_volume_path = os.path.join(self.root_folder, "schemas")) + self._normalize_storage_config = NormalizeVolumeConfiguration(normalize_volume_path=os.path.join(self.root_folder, "normalize")) + self._load_storage_config = LoadVolumeConfiguration(load_volume_path=os.path.join(self.root_folder, "load"),) # create pipeline working dir self._pipeline_storage = FileStorage(self.root_folder, makedirs=False) @@ -106,147 +122,153 @@ def _configure(self) -> None: self._create_pipeline() # create schema storage - self._schema_storage = LiveSchemaStorage(makedirs=True) - # create extractor storage - self._extractor_storage = ExtractorStorageBase( - "1.0.0", - True, - FileStorage(os.path.join(self.root_folder, "extract"), makedirs=True), - self._ensure_normalize_storage() - ) + self._schema_storage = LiveSchemaStorage(self._schema_storage_config, makedirs=True) - initialize_runner(self.CONFIG) - - def drop() -> "Pipeline": + def drop(self) -> "Pipeline": """Deletes existing pipeline state, schemas and drops datasets at the destination if present""" pass - def _get_config(self, spec: Type[TAny], accept_partial: bool = False) -> Type[TAny]: - print(self._initial_values) - return make_configuration(spec, spec, initial_values=self._initial_values, accept_partial=accept_partial) - + # @overload + # def extract( + # self, + # data: Union[Iterator[TResolvableDataItem], Iterable[TResolvableDataItem]], + # table_name = None, + # write_disposition = None, + # parent = None, + # columns = None, + # max_parallel_data_items: int = 20, + # schema: Schema = None + # ) -> None: + # ... + + # @overload + # def extract( + # self, + # data: DltSource, + # max_parallel_iterators: int = 1, + # max_parallel_data_items: int = 20, + # schema: Schema = None + # ) -> None: + # ... - @overload - def extract( - self, - data: Union[Iterator[TResolvableDataItem], Iterable[TResolvableDataItem]], - table_name = None, - write_disposition = None, - parent = None, - columns = None, - max_parallel_data_items: int = 20, - schema: Schema = None - ) -> None: - ... - - @overload - def extract( - self, - data: DltSource, - max_parallel_iterators: int = 1, - max_parallel_data_items: int = 20, - schema: Schema = None - ) -> None: - ... 
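
The `with_config_namespace` decorator above scopes configuration resolution for each pipeline step. A minimal sketch of the underlying mechanism, assuming only the `inject_namespace` helper and context imported at the top of this file (the pipeline name is illustrative):

    from dlt.common.configuration import inject_namespace
    from dlt.common.configuration.specs.config_namespace_context import ConfigNamespacesContext

    with inject_namespace(ConfigNamespacesContext(pipeline_name="my_pipeline", namespaces=("normalize",))):
        # any @with_config-decorated constructor resolved in this block sees the
        # ("my_pipeline", "normalize") namespaces without them being passed explicitly
        ...
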
- - @maybe_default_config @with_schemas_sync @with_state_sync + @with_config_namespace(("extract",)) def extract( self, - data: Union[DltSource, Iterator[TResolvableDataItem], Iterable[TResolvableDataItem]], + data: Union[DltSource, DltResource, Iterator[TResolvableDataItem], Iterable[TResolvableDataItem]], table_name = None, write_disposition = None, parent = None, columns = None, - max_parallel_iterators: int = 1, - max_parallel_data_items: int = 20, - schema: Schema = None + schema: Schema = None, + *, + max_parallel_items: int = 100, + workers: int = 5 ) -> None: - self._schema_storage.save_schema(schema) - self._state["default_schema_name"] = schema.name - # TODO: apply hints to table - - # check if iterator or iterable is supported - # if isinstance(items, str) or isinstance(items, dict) or not - # TODO: check if schema exists - with self._managed_state(): - default_table_name = table_name or self.CONFIG.pipeline_name - # TODO: this is not very effective - we consume iterator right away, better implementation needed where we stream iterator to files directly - all_items: List[DictStrAny] = [] - for item in data: - # dispatch items by type - if callable(item): - item = item() - if isinstance(item, dict): - all_items.append(item) - elif isinstance(item, abc.Sequence): - all_items.extend(item) - # react to CTRL-C and shutdowns from controllers - signals.raise_if_signalled() - - try: - self._extract_iterator(default_table_name, all_items) - except Exception: - raise PipelineStepFailed("extract", self.last_run_exception, runner.LAST_RUN_METRICS) - - # @maybe_default_config - # @with_schemas_sync - # @with_state_sync - # def extract_many() -> None: - # pass + + def only_data_args(with_schema: bool) -> None: + if not table_name or not write_disposition or not parent or not columns: + raise InvalidExtractArguments(with_schema) + if not with_schema and not schema: + raise InvalidExtractArguments(with_schema) + + def choose_schema() -> Schema: + if schema: + return schema + if self.default_schema_name: + return self.default_schema + return Schema(self.pipeline_name) + + source: DltSource = None + + if isinstance(data, DltSource): + # already a source + only_data_args(with_schema=False) + source = data + elif isinstance(data, DltResource): + # package resource in source + only_data_args(with_schema=True) + source = DltSource(choose_schema(), [data]) + else: + table_schema: TTableSchemaTemplate = { + "name": table_name, + "parent": parent, + "write_disposition": write_disposition, + "columns": columns + } + # convert iterable to resource + data = DltResource.from_data(data, name=table_name, table_schema_template=table_schema) + # wrap resource in source + source = DltSource(choose_schema(), [data]) + + try: + self._extract_source(source, max_parallel_items, workers) + except Exception as exc: + raise PipelineStepFailed("extract", self.last_run_exception, runner.LAST_RUN_METRICS) from exc + @with_schemas_sync - def normalize(self, dry_run: bool = False, workers: int = 1, max_events_in_chunk: int = 100000) -> None: + @with_config_namespace(("normalize",)) + def normalize(self, workers: int = 1, dry_run: bool = False) -> None: if is_interactive() and workers > 1: raise NotImplementedError("Do not use normalize workers in interactive mode ie. 
in notebook") - # set parameters to be passed to config - normalize = self._configure_normalize({ - "WORKERS": workers, - "POOL_TYPE": "thread" if workers == 1 else "process" - }) - try: - ec = runner.run_pool(normalize.config, normalize) - # in any other case we raise if runner exited with status failed - if runner.LAST_RUN_METRICS.has_failed: - raise PipelineStepFailed("normalize", self.last_run_exception, runner.LAST_RUN_METRICS) - return ec - except Exception as r_ex: - # if EXIT_ON_EXCEPTION flag is set, exception will bubble up directly - raise PipelineStepFailed("normalize", self.last_run_exception, runner.LAST_RUN_METRICS) from r_ex - finally: - signals.raise_if_signalled() + + # get destination capabilities + destination_caps = self._get_destination_capabilities() + # create default normalize config + normalize_config = NormalizeConfiguration( + is_single_run=True, + exit_on_exception=True, + workers=workers, + pool_type="none" if workers == 1 else "process", + schema_storage_config=self._schema_storage_config, + normalize_storage_config=self._normalize_storage_config, + load_storage_config=self._load_storage_config + ) + # run with destination context + with self._container.injectable_context(destination_caps): + # shares schema storage with the pipeline so we do not need to install + normalize = Normalize(config=normalize_config, schema_storage=self._schema_storage) + self._run_step_in_pool("normalize", normalize, normalize.config) @with_schemas_sync @with_state_sync + @with_config_namespace(("load",)) def load( self, - destination_name: str = None, - default_dataset: str = None, - credentials: TCredentials = None, - raise_on_failed_jobs = False, - raise_on_incompatible_schema = False, - always_drop_dataset = False, - dry_run: bool = False, - max_parallel_loads: int = 20, - normalize_workers: int = 1 + destination: DestinationReference = None, + dataset_name: str = None, + credentials: Any = None, + # raise_on_failed_jobs = False, + # raise_on_incompatible_schema = False, + # always_drop_dataset = False, + *, + workers: int = 20 ) -> None: - self._resolve_load_client_config() - # check if anything to normalize - if len(self._extractor_storage.normalize_storage.list_files_to_normalize_sorted()) > 0: - self.normalize(dry_run=dry_run, workers=normalize_workers) - # then load - print(locals()) - load = self._configure_load(locals(), credentials) - runner.run_pool(load.config, load) - if runner.LAST_RUN_METRICS.has_failed: - raise PipelineStepFailed("load", self.last_run_exception, runner.LAST_RUN_METRICS) - - def activate(self) -> None: - # make this instance the active one - pass + + # set destination and default dataset if provided + self.destination = destination or self.destination + self.dataset_name = dataset_name or self.dataset_name + # check if any schema is present, if not then no data was extracted + if not self.default_schema_name: + return + + # make sure that destination is set and client is importable and can be instantiated + client_initial_config = self._get_destination_client_initial_config(credentials) + self._get_destination_client(self.default_schema, client_initial_config) + + # create initial loader config and the loader + load_config = LoaderConfiguration( + is_single_run=True, + exit_on_exception=True, + workers=workers, + load_storage_config=self._load_storage_config + ) + load = Load(self.destination, is_storage_owner=False, config=load_config, initial_client_config=client_initial_config) + self._run_step_in_pool("load", load, load.config) @property def 
schemas(self) -> Mapping[str, Schema]: @@ -254,7 +276,7 @@ def schemas(self) -> Mapping[str, Schema]: @property def default_schema(self) -> Schema: - return self.schemas[self._state.get("default_schema_name")] + return self.schemas[self.default_schema_name] @property def last_run_exception(self) -> BaseException: @@ -266,150 +288,101 @@ def _create_pipeline(self) -> None: def _restore_pipeline(self) -> None: self._restore_state() - def _ensure_normalize_storage(self) -> NormalizeStorage: - return NormalizeStorage(True, self._get_config(NormalizeVolumeConfiguration)) - - def _configure_normalize(self, initial_values: DictStrAny) -> Normalize: - destination_name = self._ensure_destination_name() - format = self._get_loader_capabilities(destination_name)["preferred_loader_file_format"] - # create normalize config - initial_values.update({ - "LOADER_FILE_FORMAT": format, - "ADD_EVENT_JSON": False - }) - # apply schema storage config - # initial_values.update(self._schema_storage.C.as_dict()) - # apply common initial settings - initial_values.update(self._initial_values) - C = normalize_configuration(initial_values=initial_values) - print(C.as_dict()) - # shares schema storage with the pipeline so we do not need to install - return Normalize(C, schema_storage=self._schema_storage) - - def _configure_load(self, loader_initial: DictStrAny, credentials: TCredentials = None) -> Load: - # get destination or raise - destination_name = self._ensure_destination_name() - # import load client for given destination or raise - self._get_loader_capabilities(destination_name) - # get default dataset or raise - default_dataset = self._ensure_default_dataset() - - loader_initial.update({ - "DELETE_COMPLETED_JOBS": True, - "CLIENT_TYPE": destination_name - }) - loader_initial.update(self._initial_values) - - loader_client_initial = { - "DEFAULT_DATASET": default_dataset, - "DEFAULT_SCHEMA_NAME": self._state.get("default_schema_name") - } - if credentials: - loader_client_initial.update(credentials) - - C = loader_configuration(initial_values=loader_initial) - return Load(C, REGISTRY, client_initial_values=loader_client_initial, is_storage_owner=False) - - def _set_common_initial_values(self) -> None: - self._initial_values.update({ - "IS_SINGLE_RUN": True, - "EXIT_ON_EXCEPTION": True, - "LOAD_VOLUME_PATH": os.path.join(self.root_folder, "load"), - "NORMALIZE_VOLUME_PATH": os.path.join(self.root_folder, "normalize"), - "SCHEMA_VOLUME_PATH": os.path.join(self.root_folder, "schemas") - }) - - def _get_loader_capabilities(self, destination_name: str) -> TLoaderCapabilities: + def _restore_state(self) -> None: + self._state.clear() # type: ignore + restored_state: TPipelineState = json.loads(self._pipeline_storage.load(Pipeline.STATE_FILE)) + self._state.update(restored_state) + + def _extract_source(self, source: DltSource, max_parallel_items: int, workers: int) -> None: + storage = ExtractorStorage(self._normalize_storage_config) + + for _, partials in extract(source, storage, max_parallel_items=max_parallel_items, workers=workers).items(): + for partial in partials: + source.schema.update_schema(source.schema.normalize_table_identifiers(partial)) + + # save schema and set as default if this is first one + self._schema_storage.save_schema(source.schema) + if not self.default_schema_name: + self.default_schema_name = source.schema.name + + def _run_step_in_pool(self, step: TPipelineStep, runnable: Runnable[Any], config: PoolRunnerConfiguration) -> int: try: - return Load.loader_capabilities(destination_name) + ec = 
runner.run_pool(config, runnable) + # in any other case we raise if runner exited with status failed + if runner.LAST_RUN_METRICS.has_failed: + raise PipelineStepFailed(step, self.last_run_exception, runner.LAST_RUN_METRICS) + return ec + except Exception as r_ex: + # if EXIT_ON_EXCEPTION flag is set, exception will bubble up directly + raise PipelineStepFailed(step, self.last_run_exception, runner.LAST_RUN_METRICS) from r_ex + finally: + signals.raise_if_signalled() + + def _get_destination_client_initial_config(self, credentials: Any) -> DestinationClientConfiguration: + if not self.destination: + raise PipelineConfigMissing( + "destination", + "load", + "Please provide `destination` argument to `config` or `load` method or via pipeline config file or environment var." + ) + dataset_name = self._get_dataset_name() + # create initial destination client config + client_spec = self.destination.spec() + if issubclass(client_spec, DestinationClientDwhConfiguration): + # client support schemas and datasets + return client_spec(dataset_name=dataset_name, default_schema_name=self.default_schema_name, credentials=credentials) + else: + return client_spec(credentials=credentials) + + def _get_destination_client(self, schema: Schema, initial_config: DestinationClientConfiguration = None) -> JobClientBase: + try: + return self.destination.client(schema, initial_config) except ImportError: + client_spec = self.destination.spec() raise MissingDependencyException( - f"{destination_name} destination", - [f"python-dlt[{destination_name}]"], + f"{client_spec.destination_name} destination", + [f"python-dlt[{client_spec.destination_name}]"], "Dependencies for specific destinations are available as extras of python-dlt" ) - def _resolve_load_client_config(self) -> Type[LoaderClientDwhConfiguration]: - return get_config( - LoaderClientDwhConfiguration, - initial_values={ - "client_type": self._initial_values.get("destination_name"), - "default_dataset": self._initial_values.get("default_dataset") - }, - accept_partial=True - ) - - def _ensure_destination_name(self) -> str: - d_n = self._resolve_load_client_config().client_type - if not d_n: + def _get_destination_capabilities(self) -> DestinationCapabilitiesContext: + if not self.destination: raise PipelineConfigMissing( - "destination_name", + "destination", "normalize", - "Please provide `destination_name` argument to `config` or `load` method or via pipeline config file or environment var." + "Please provide `destination` argument to `config` or `load` method or via pipeline config file or environment var." 
) - return d_n - - def _ensure_default_dataset(self) -> str: - d_n = self._resolve_load_client_config().default_dataset - if not d_n: - d_n = normalize_schema_name(self.CONFIG.pipeline_name) - return d_n + return self.destination.capabilities() - def _extract_iterator(self, default_table_name: str, items: Sequence[DictStrAny]) -> None: - try: - for idx, i in enumerate(items): - if not isinstance(i, dict): - # TODO: convert non dict types into dict - items[idx] = i = {"v": i} - if DLT_METADATA_FIELD not in i or i.get(DLT_METADATA_FIELD, None) is None: - # set default table name - with_table_name(i, default_table_name) - - load_id = uniq_id() - self._extractor_storage.save_json(f"{load_id}.json", items) - self._extractor_storage.commit_events( - self.default_schema.name, - self._extractor_storage.storage.make_full_path(f"{load_id}.json"), - default_table_name, - len(items), - load_id - ) - - runner.LAST_RUN_METRICS = TRunMetrics(was_idle=False, has_failed=False, pending_items=0) - except Exception as ex: - logger.exception("extracting iterator failed") - runner.LAST_RUN_METRICS = TRunMetrics(was_idle=False, has_failed=True, pending_items=0) - runner.LAST_RUN_EXCEPTION = ex - raise + def _get_dataset_name(self) -> str: + return self.dataset_name or self.pipeline_name @contextmanager - def _managed_state(self) -> Iterator[None]: + def _managed_state(self) -> Iterator[TPipelineState]: + # write props to pipeline variables + for prop in Pipeline.STATE_PROPS: + setattr(self, prop, self._state.get(prop)) + # backup the state backup_state = deepcopy(self._state) try: - yield + yield self._state except Exception: # restore old state - self._state.clear() + self._state.clear() # type: ignore self._state.update(backup_state) raise else: + # update state props + for prop in Pipeline.STATE_PROPS: + self._state[prop] = getattr(self, prop) + # compare backup and new state, save only if different + new_state = json.dumps(self._state) + old_state = json.dumps(backup_state) # persist old state - # TODO: compare backup and new state, save only if different - self._pipeline_storage.save(Pipeline.STATE_FILE, json.dumps(self._state)) - - def _restore_state(self) -> None: - self._state.clear() - restored_state: DictStrAny = json.loads(self._pipeline_storage.load(Pipeline.STATE_FILE)) - self._state.update(restored_state) - - @property - def is_active(self) -> bool: - return id(self) == id(Pipeline.ACTIVE_INSTANCE) + if new_state != old_state: + self._pipeline_storage.save(Pipeline.STATE_FILE, new_state) @property def has_pending_loads(self) -> bool: # TODO: check if has pending normalizer and loader data pass - -# active instance always present -Pipeline.ACTIVE_INSTANCE = Pipeline() diff --git a/experiments/pipeline/typing.py b/experiments/pipeline/typing.py index cb06d5a97f..9fbf917265 100644 --- a/experiments/pipeline/typing.py +++ b/experiments/pipeline/typing.py @@ -1,14 +1,16 @@ -from typing import Literal +from typing import Literal, TypedDict, Optional TPipelineStep = Literal["extract", "normalize", "load"] +class TPipelineState(TypedDict): + pipeline_name: str + dataset_name: str + default_schema_name: Optional[str] + # destination_name: Optional[str] -# class TTableSchema(TTableSchema, total=False): -# name: Optional[str] -# description: Optional[str] -# write_disposition: Optional[TWriteDisposition] -# table_sealed: Optional[bool] -# parent: Optional[str] -# filters: Optional[TRowFilters] -# columns: TTableSchemaColumns \ No newline at end of file + +# TSourceState = NewType("TSourceState", DictStrAny) + 
+# class TPipelineState() +# sources: Dict[str, TSourceState] diff --git a/tests/common/runners/test_runners.py b/tests/common/runners/test_runners.py index 220e62dd58..35ac059371 100644 --- a/tests/common/runners/test_runners.py +++ b/tests/common/runners/test_runners.py @@ -5,7 +5,7 @@ from dlt.cli import TRunnerArgs from dlt.common import signals -from dlt.common.configuration import make_configuration, configspec +from dlt.common.configuration import resolve_configuration, configspec from dlt.common.configuration.specs import PoolRunnerConfiguration, TPoolType from dlt.common.exceptions import DltException, SignalReceivedException, TimeRangeExhaustedException, UnsupportedProcessStartMethodException from dlt.common.runners import pool_runner as runner @@ -46,12 +46,12 @@ class ThreadPoolConfiguration(ModPoolRunnerConfiguration): def configure(C: Type[PoolRunnerConfiguration], args: TRunnerArgs) -> PoolRunnerConfiguration: - return make_configuration(C(), initial_value=args._asdict()) + return resolve_configuration(C(), initial_value=args._asdict()) @pytest.fixture(scope="module", autouse=True) def logger_autouse() -> None: - init_logger(ModPoolRunnerConfiguration) + init_logger() @pytest.fixture(autouse=True) diff --git a/tests/common/schema/test_schema.py b/tests/common/schema/test_schema.py index f8e39ae099..37ab37630e 100644 --- a/tests/common/schema/test_schema.py +++ b/tests/common/schema/test_schema.py @@ -2,7 +2,7 @@ import pytest from dlt.common import pendulum -from dlt.common.configuration import make_configuration +from dlt.common.configuration import resolve_configuration from dlt.common.configuration.specs import SchemaVolumeConfiguration from dlt.common.exceptions import DictValidationException from dlt.common.schema.typing import TColumnName, TSimpleRegex, COLUMN_HINTS @@ -21,7 +21,7 @@ @pytest.fixture def schema_storage() -> SchemaStorage: - C = make_configuration( + C = resolve_configuration( SchemaVolumeConfiguration(), initial_value={ "import_schema_path": "tests/common/cases/schemas/rasa", diff --git a/tests/common/storages/test_loader_storage.py b/tests/common/storages/test_loader_storage.py index 84a938d663..980364946e 100644 --- a/tests/common/storages/test_loader_storage.py +++ b/tests/common/storages/test_loader_storage.py @@ -6,7 +6,7 @@ from dlt.common import sleep from dlt.common.schema import Schema from dlt.common.storages.load_storage import LoadStorage, TParsedJobFileName -from dlt.common.configuration import make_configuration +from dlt.common.configuration import resolve_configuration from dlt.common.configuration.specs import LoadVolumeConfiguration from dlt.common.storages.exceptions import NoMigrationPathException from dlt.common.typing import StrAny @@ -17,14 +17,14 @@ @pytest.fixture def storage() -> LoadStorage: - C = make_configuration(LoadVolumeConfiguration()) - s = LoadStorage(True, C, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + C = resolve_configuration(LoadVolumeConfiguration()) + s = LoadStorage(True, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS, C) return s def test_complete_successful_package(storage: LoadStorage) -> None: # should delete package in full - storage.delete_completed_jobs = True + storage.config.delete_completed_jobs = True load_id, file_name = start_loading_file(storage, [{"content": "a"}, {"content": "b"}]) assert storage.storage.has_folder(storage.get_package_path(load_id)) storage.complete_job(load_id, file_name) @@ -35,7 +35,7 @@ def test_complete_successful_package(storage: LoadStorage) -> None: assert not 
storage.storage.has_folder(storage.get_completed_package_path(load_id)) # do not delete completed jobs - storage.delete_completed_jobs = False + storage.config.delete_completed_jobs = False load_id, file_name = start_loading_file(storage, [{"content": "a"}, {"content": "b"}]) storage.complete_job(load_id, file_name) storage.complete_load_package(load_id) @@ -47,7 +47,7 @@ def test_complete_successful_package(storage: LoadStorage) -> None: def test_complete_package_failed_jobs(storage: LoadStorage) -> None: # loads with failed jobs are always persisted - storage.delete_completed_jobs = True + storage.config.delete_completed_jobs = True load_id, file_name = start_loading_file(storage, [{"content": "a"}, {"content": "b"}]) assert storage.storage.has_folder(storage.get_package_path(load_id)) storage.fail_job(load_id, file_name, "EXCEPTION") @@ -142,22 +142,22 @@ def test_process_schema_update(storage: LoadStorage) -> None: def test_full_migration_path() -> None: # create directory structure - s = LoadStorage(True, LoadVolumeConfiguration, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + s = LoadStorage(True, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) # overwrite known initial version write_version(s.storage, "1.0.0") # must be able to migrate to current version - s = LoadStorage(False, LoadVolumeConfiguration, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + s = LoadStorage(False, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) assert s.version == LoadStorage.STORAGE_VERSION def test_unknown_migration_path() -> None: # create directory structure - s = LoadStorage(True, LoadVolumeConfiguration, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + s = LoadStorage(True, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) # overwrite known initial version write_version(s.storage, "10.0.0") # must be able to migrate to current version with pytest.raises(NoMigrationPathException): - LoadStorage(False, LoadVolumeConfiguration, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + LoadStorage(False, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) def start_loading_file(s: LoadStorage, content: Sequence[StrAny]) -> Tuple[str, str]: diff --git a/tests/common/storages/test_schema_storage.py b/tests/common/storages/test_schema_storage.py index 95cb9406f8..10d4ccaf6f 100644 --- a/tests/common/storages/test_schema_storage.py +++ b/tests/common/storages/test_schema_storage.py @@ -8,7 +8,7 @@ from dlt.common.schema.schema import Schema from dlt.common.schema.typing import TStoredSchema from dlt.common.schema.utils import default_normalizers -from dlt.common.configuration import make_configuration +from dlt.common.configuration import resolve_configuration from dlt.common.configuration.specs import SchemaVolumeConfiguration from dlt.common.storages.exceptions import InStorageSchemaModified, SchemaNotFoundError from dlt.common.storages import SchemaStorage, LiveSchemaStorage, FileStorage @@ -37,7 +37,7 @@ def ie_storage() -> SchemaStorage: def init_storage(C: SchemaVolumeConfiguration) -> SchemaStorage: # use live schema storage for test which must be backward compatible with schema storage s = LiveSchemaStorage(C, makedirs=True) - assert C is s.C + assert C is s.config if C.export_schema_path: os.makedirs(C.export_schema_path, exist_ok=True) if C.import_schema_path: @@ -85,7 +85,7 @@ def test_skip_import_if_not_modified(synced_storage: SchemaStorage, storage: Sch # the import schema gets modified storage_schema.tables["_dlt_loads"]["write_disposition"] = "append" storage_schema.tables.pop("event_user") - 
synced_storage._export_schema(storage_schema, synced_storage.C.export_schema_path) + synced_storage._export_schema(storage_schema, synced_storage.config.export_schema_path) # now load will import again reloaded_schema = synced_storage.load_schema("ethereum") # we have overwritten storage schema @@ -112,7 +112,7 @@ def test_store_schema_tampered(synced_storage: SchemaStorage, storage: SchemaSto def test_schema_export(ie_storage: SchemaStorage) -> None: schema = Schema("ethereum") - fs = FileStorage(ie_storage.C.export_schema_path) + fs = FileStorage(ie_storage.config.export_schema_path) exported_name = ie_storage._file_name_in_store("ethereum", "yaml") # no exported schema assert not fs.has_file(exported_name) @@ -191,7 +191,7 @@ def test_save_store_schema_over_import(ie_storage: SchemaStorage) -> None: assert schema.version_hash == schema_hash assert schema._imported_version_hash == "njJAySgJRs2TqGWgQXhP+3pCh1A1hXcqe77BpM7JtOU=" # we have simple schema in export folder - fs = FileStorage(ie_storage.C.export_schema_path) + fs = FileStorage(ie_storage.config.export_schema_path) exported_name = ie_storage._file_name_in_store("ethereum", "yaml") exported_schema = yaml.safe_load(fs.load(exported_name)) assert schema.version_hash == exported_schema["version_hash"] @@ -205,7 +205,7 @@ def test_save_store_schema_over_import_sync(synced_storage: SchemaStorage) -> No synced_storage.save_schema(schema) assert schema._imported_version_hash == "njJAySgJRs2TqGWgQXhP+3pCh1A1hXcqe77BpM7JtOU=" # import schema is overwritten - fs = FileStorage(synced_storage.C.import_schema_path) + fs = FileStorage(synced_storage.config.import_schema_path) exported_name = synced_storage._file_name_in_store("ethereum", "yaml") exported_schema = yaml.safe_load(fs.load(exported_name)) assert schema.version_hash == exported_schema["version_hash"] == schema_hash diff --git a/tests/dbt_runner/test_runner_bigquery.py b/tests/dbt_runner/test_runner_bigquery.py index bed8b2304b..c11e9f992e 100644 --- a/tests/dbt_runner/test_runner_bigquery.py +++ b/tests/dbt_runner/test_runner_bigquery.py @@ -9,7 +9,7 @@ from dlt.dbt_runner.utils import DBTProcessingError from dlt.dbt_runner import runner -from dlt.load.bigquery.client import BigQuerySqlClient +from dlt.load.bigquery.bigquery import BigQuerySqlClient from tests.utils import add_config_to_env, init_logger, preserve_environ from tests.dbt_runner.utils import setup_runner diff --git a/tests/dbt_runner/test_runner_redshift.py b/tests/dbt_runner/test_runner_redshift.py index d7e4d42b15..8e650602df 100644 --- a/tests/dbt_runner/test_runner_redshift.py +++ b/tests/dbt_runner/test_runner_redshift.py @@ -4,7 +4,7 @@ from prometheus_client import CollectorRegistry from dlt.common import logger -from dlt.common.configuration import make_configuration +from dlt.common.configuration import resolve_configuration from dlt.common.configuration.specs import PostgresCredentials from dlt.common.storages import FileStorage from dlt.common.telemetry import TRunMetrics, get_metrics_from_prometheus @@ -14,7 +14,7 @@ from dlt.dbt_runner.utils import DBTProcessingError from dlt.dbt_runner.configuration import DBTRunnerConfiguration from dlt.dbt_runner import runner -from dlt.load.redshift.client import RedshiftSqlClient +from dlt.load.redshift.redshift import RedshiftSqlClient from tests.utils import add_config_to_env, clean_test_storage, init_logger, preserve_environ from tests.dbt_runner.utils import modify_and_commit_file, load_secret, setup_runner @@ -61,13 +61,13 @@ def module_autouse() -> None: def 
test_configuration() -> None: # check names normalized - C = make_configuration( + C = resolve_configuration( DBTRunnerConfiguration(), initial_value={"PACKAGE_REPOSITORY_SSH_KEY": "---NO NEWLINE---", "SOURCE_SCHEMA_PREFIX": "schema"} ) assert C.package_repository_ssh_key == "---NO NEWLINE---\n" - C = make_configuration( + C = resolve_configuration( DBTRunnerConfiguration(), initial_value={"PACKAGE_REPOSITORY_SSH_KEY": "---WITH NEWLINE---\n", "SOURCE_SCHEMA_PREFIX": "schema"} ) diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py index 883a8b44dc..f9c2abbe6e 100644 --- a/tests/normalize/test_normalize.py +++ b/tests/normalize/test_normalize.py @@ -6,10 +6,13 @@ from multiprocessing.dummy import Pool as ThreadPool from dlt.common import json +from dlt.common.configuration.specs.destination_capabilities_context import TLoaderFileFormat from dlt.common.utils import uniq_id from dlt.common.typing import StrAny from dlt.common.schema import TDataType from dlt.common.storages import NormalizeStorage, LoadStorage +from dlt.common.configuration.specs import DestinationCapabilitiesContext +from dlt.common.configuration.container import Container from dlt.extract.extract import ExtractorStorage from dlt.normalize import Normalize @@ -35,10 +38,13 @@ def rasa_normalize() -> Normalize: def init_normalize(default_schemas_path: str = None) -> Normalize: clean_test_storage() - with TEST_DICT_CONFIG_PROVIDER.values({"import_schema_path": default_schemas_path, "external_schema_format": "json"}): - n = Normalize(collector=CollectorRegistry()) - # set jsonl as default writer - n.load_storage.loader_file_format = n.config.loader_file_format = "jsonl" + # pass schema config fields to schema storage via dict config provider + with TEST_DICT_CONFIG_PROVIDER().values({"import_schema_path": default_schemas_path, "external_schema_format": "json"}): + # inject the destination capabilities + with Container().injectable_context(DestinationCapabilitiesContext(preferred_loader_file_format="jsonl")): + n = Normalize(collector=CollectorRegistry()) + + assert n.load_storage.loader_file_format == n.loader_file_format == "jsonl" return n @@ -73,7 +79,8 @@ def test_normalize_single_user_event_jsonl(raw_normalize: Normalize) -> None: def test_normalize_single_user_event_insert(raw_normalize: Normalize) -> None: - raw_normalize.load_storage.loader_file_format = raw_normalize.config.loader_file_format = "insert_values" + mock_destination_caps(raw_normalize, "insert_values") + raw_normalize.load_storage.loader_file_format = raw_normalize.loader_file_format = "insert_values" expected_tables, load_files = normalize_event_user(raw_normalize, "event.event.user_load_1", EXPECTED_USER_TABLES) # verify values line for expected_table in expected_tables: @@ -129,7 +136,7 @@ def test_preserve_slot_complex_value_json_l(rasa_normalize: Normalize) -> None: def test_preserve_slot_complex_value_insert(rasa_normalize: Normalize) -> None: - rasa_normalize.load_storage.loader_file_format = rasa_normalize.config.loader_file_format = "insert_values" + mock_destination_caps(rasa_normalize, "insert_values") load_id = normalize_cases(rasa_normalize, ["event.event.slot_session_metadata_1"]) load_files = expect_load_package(rasa_normalize.load_storage, load_id, ["event", "event_slot"]) event_text, lines = expect_lines_file(rasa_normalize.load_storage, load_files["event_slot"], 2) @@ -152,7 +159,7 @@ def test_normalize_raw_type_hints(rasa_normalize: Normalize) -> None: def test_normalize_many_events_insert(rasa_normalize: 
Normalize) -> None: - rasa_normalize.load_storage.loader_file_format = rasa_normalize.config.loader_file_format = "insert_values" + mock_destination_caps(rasa_normalize, "insert_values") load_id = normalize_cases(rasa_normalize, ["event.event.many_load_2", "event.event.user_load_1"]) expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + ["event_bot", "event_action"] load_files = expect_load_package(rasa_normalize.load_storage, load_id, expected_tables) @@ -175,7 +182,7 @@ def test_normalize_many_events(rasa_normalize: Normalize) -> None: def test_normalize_many_schemas(rasa_normalize: Normalize) -> None: - rasa_normalize.load_storage.loader_file_format = rasa_normalize.config.loader_file_format = "insert_values" + mock_destination_caps(rasa_normalize, "insert_values") extract_cases( rasa_normalize.normalize_storage, ["event.event.many_load_2", "event.event.user_load_1", "ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2"] @@ -204,7 +211,7 @@ def test_normalize_many_schemas(rasa_normalize: Normalize) -> None: def test_normalize_typed_json(raw_normalize: Normalize) -> None: - raw_normalize.load_storage.loader_file_format = raw_normalize.config.loader_file_format = "jsonl" + mock_destination_caps(raw_normalize, "jsonl") extract_items(raw_normalize.normalize_storage, [JSON_TYPED_DICT], "special", "special") raw_normalize.run(ThreadPool(processes=1)) loads = raw_normalize.load_storage.list_packages() @@ -291,3 +298,9 @@ def assert_timestamp_data_type(load_storage: LoadStorage, data_type: TDataType) event_schema = load_storage.load_package_schema(loads[0]) # in raw normalize timestamp column must not be coerced to timestamp assert event_schema.get_table_columns("event")["timestamp"]["data_type"] == data_type + + +def mock_destination_caps(n: Normalize, loader_file_format: TLoaderFileFormat) -> None: + # mock the loader file format + # TODO: mock full capabilities here + n.load_storage.loader_file_format = n.loader_file_format = loader_file_format diff --git a/tests/utils.py b/tests/utils.py index 56c76a642d..b586b8ef3b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -8,9 +8,9 @@ from dlt.common.configuration.container import Container from dlt.common.configuration.providers import EnvironProvider, DictionaryProvider -from dlt.common.configuration.resolve import make_configuration, serialize_value +from dlt.common.configuration.resolve import resolve_configuration, serialize_value from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration -from dlt.common.configuration.specs.config_providers_configuration import ConfigProvidersListConfiguration +from dlt.common.configuration.specs.config_providers_context import ConfigProvidersListContext from dlt.common.logger import init_logging_from_config from dlt.common.storages import FileStorage from dlt.common.schema import Schema @@ -21,10 +21,14 @@ TEST_STORAGE_ROOT = "_storage" # add test dictionary provider -TEST_DICT_CONFIG_PROVIDER = DictionaryProvider() -providers_config = Container()[ConfigProvidersListConfiguration] -providers_config.providers.append(TEST_DICT_CONFIG_PROVIDER) - +def TEST_DICT_CONFIG_PROVIDER(): + providers_context = Container()[ConfigProvidersListContext] + try: + return providers_context.get_provider(DictionaryProvider.NAME) + except KeyError: + provider = DictionaryProvider() + providers_context.add_provider(provider) + return provider class MockHttpResponse(): @@ -67,7 +71,7 @@ def preserve_environ() -> None: def init_logger(C: RunConfiguration = None) -> None: if not hasattr(logging, "health"): if 
not C: - C = make_configuration(RunConfiguration()) + C = resolve_configuration(RunConfiguration()) init_logging_from_config(C) From 47277d61f5e66e2b9b945c2e9081a8630b90ea8a Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Thu, 20 Oct 2022 11:37:02 +0200 Subject: [PATCH 40/66] adding pseudo code samples for api v2: create pipeline, working with credentials --- experiments/pipeline/create_pipeline.md | 119 -------------- .../pipeline/examples/create_pipeline.md | 130 +++++++++++++++ .../pipeline/examples/general_usage.md | 0 .../pipeline/examples/project_structure.md | 0 .../pipeline/examples/secrets_and_config.md | 155 ++++++++++++++++++ 5 files changed, 285 insertions(+), 119 deletions(-) delete mode 100644 experiments/pipeline/create_pipeline.md create mode 100644 experiments/pipeline/examples/create_pipeline.md create mode 100644 experiments/pipeline/examples/general_usage.md create mode 100644 experiments/pipeline/examples/project_structure.md create mode 100644 experiments/pipeline/examples/secrets_and_config.md diff --git a/experiments/pipeline/create_pipeline.md b/experiments/pipeline/create_pipeline.md deleted file mode 100644 index 058331271c..0000000000 --- a/experiments/pipeline/create_pipeline.md +++ /dev/null @@ -1,119 +0,0 @@ -## Mockup code for generic template credentials - -This is a toml file, for BigQuery credentials destination and instruction how to add source credentials. - -I assume that new pipeline is accessing REST API - -```toml -# provide credentials to `taktile` source below, for example -# api_key = "api key to access taktile endpoint" - -[gcp_credentials] -client_email = -private_key = -project_id = -``` - -## Mockup code for taktile credentials - -```toml -taktile_api_key="96e6m3/OFSumLRG9mnIr" - -[gcp_credentials] -client_email = -private_key = -project_id = -``` - - -## Mockup code pipeline script template with nice UX - -This is a template made for BiqQuery destination and the source named `taktile`. This already proposes a nice structure for the code so the pipeline may be developed further. 
- - -```python -import requests -import dlt - -# The code below is an example of well structured pipeline -# @Ty if you want I can write more comments and explanations - -@dlt.source -def taktile_data(): - # retrieve credentials via DLT secrets - api_key = dlt.secrets["api_key"] - - # make a call to the endpoint with request library - resp = requests.get("https://example.com/data", headers={"Authorization": api_key"}) - resp.raise_for_status() - data = resp.json() - - # you may process the data here - - # return resource to be loaded into `data` table - return dlt.resource(data, name="data") - -dlt.run(taktile_data(), destination="bigquery") -``` - - -## Mockup code of taktile pipeline script with nice UX - -Example for the simplest ad hoc pipeline without any structure - -```python -import requests -import dlt - -resp = requests.get( - "https://taktile.com/api/v2/logs", - headers={"Authorization": dlt.secrets["taktile_api_key"]}) -resp.raise_for_status() -data = resp.json() - -dlt.run(data["result"], name="logs", destination="bigquery") -``` - -Example for endpoint returning only one resource: - -```python -import requests -import dlt - -@dlt.source -def taktile_data(): - resp = requests.get( - "https://taktile.com/api/v2/logs", - headers={"Authorization": dlt.secrets["taktile_api_key"]}) - resp.raise_for_status() - data = resp.json() - - return dlt.resource(data["result"], name="logs") - -dlt.run(taktile_data(), destination="bigquery") -``` - -With two resources: - -```python -import requests -import dlt - -@dlt.source -def taktile_data(): - resp = requests.get( - "https://taktile.com/api/v2/logs", - headers={"Authorization": dlt.secrets["taktile_api_key"]}) - resp.raise_for_status() - logs = resp.json()["results"] - - resp = requests.get( - "https://taktile.com/api/v2/decisions", - headers={"Authorization": dlt.secrets["taktile_api_key"]}) - resp.raise_for_status() - decisions = resp.json()["results"] - - return dlt.resource(logs, name="logs"), dlt.resource(decisions, name="decisions") - -dlt.run(taktile_data(), destination="bigquery") -``` diff --git a/experiments/pipeline/examples/create_pipeline.md b/experiments/pipeline/examples/create_pipeline.md new file mode 100644 index 0000000000..03592228cb --- /dev/null +++ b/experiments/pipeline/examples/create_pipeline.md @@ -0,0 +1,130 @@ + +## Example for the simplest ad hoc pipeline without any structure + +```python +import requests +import dlt +from dlt.destinations import bigquery + +resp = requests.get( + "https://taktile.com/api/v2/logs?from_log_id=1", + headers={"Authorization": dlt.secrets["taktile_api_key"]}) +resp.raise_for_status() +data = resp.json() + +dlt.run(data["result"], name="logs", destination=bigquery) +``` + +## Example for endpoint returning only one resource: + +```python +import requests +import dlt + +# it will use function name `taktile_data` to name the source and schema +@dlt.source +def taktile_data(initial_log_id, taktile_api_key): + resp = requests.get( + "https://taktile.com/api/v2/logs?from_log_id=%i" % initial_log_id, + headers={"Authorization": taktile_api_key}) + resp.raise_for_status() + data = resp.json() + + # yes you can return a list of values and it will work + return dlt.resource(data["result"], name="logs") + +taktile_data(1).run(destination=bigquery) +# this below also works +# dlt.run(taktile_data(1), destination=bigquery) +``` + +## With two resources: +also shows how to select just one resource to be loaded + +```python +import requests +import dlt + +@dlt.source +def 
taktile_data(initial_log_id, taktile_api_key): + resp = requests.get( + "https://taktile.com/api/v2/logs?from_log_id=%i" % initial_log_id, + headers={"Authorization": taktile_api_key}) + resp.raise_for_status() + logs = resp.json()["results"] + + resp = requests.get( + "https://taktile.com/api/v2/decisions%i" % initial_log_id, + headers={"Authorization": taktile_api_key}) + resp.raise_for_status() + decisions = resp.json()["results"] + + return dlt.resource(logs, name="logs"), dlt.resource(decisions, name="decisions", write_disposition="replace") + +# load all resources +taktile_data(1).run(destination=bigquery) +# load only decisions +taktile_data(1).select("decisions").run(....) +``` +note: +`dlt.resource` takes all the parameters (ie. `write_disposition` or `columns` that let you define the table schema fully) + +**alternative form which uses iterators** for very long responses that for example use HTTP chunked: + +```python +import requests +import dlt + +@dlt.source +def taktile_data(initial_log_id, taktile_api_key): + + # it will use the function name `logs` to name the resource/table + # yield the data which is really long jsonl stream + @dlt.resource + def logs(): + resp = requests.get( + "https://taktile.com/api/v2/logs?from_log_id=%i" % initial_log_id, + stream=True, + headers={"Authorization": taktile_api_key}) + resp.raise_for_status() + for line in resp.text(): + yield json.loads(line) + + # here we provide name and write_disposition directly + @dlt.resource(name="decisions", write_disposition="replace") + def decisions_reader(): + resp = requests.get( + "https://taktile.com/api/v2/decisions%i" % initial_log_id, + headers={"Authorization": taktile_api_key}) + resp.raise_for_status() + return resp.json()["results"] + + return logs, decisions_reader +``` + +## With pipeline state and incremental load + + +from_log_id = dlt.state.get("from_log_id") or initial_log_id +```python +import requests +import dlt + +# it will use function name `taktile_data` to name the source and schema +@dlt.source +def taktile_data(initial_log_id, taktile_api_key): + from_log_id = dlt.state.get("from_log_id") or initial_log_id + resp = requests.get( + "https://taktile.com/api/v2/logs?from_log_id=%i" % initial_log_id, + headers={"Authorization": taktile_api_key}) + resp.raise_for_status() + data = resp.json() + + # write state before returning data + + # yes you can return a list of values and it will work + yield dlt.resource(data["result"], name="logs") + + +taktile_data(1).run(destination=bigquery) + diff --git a/experiments/pipeline/examples/general_usage.md b/experiments/pipeline/examples/general_usage.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/experiments/pipeline/examples/project_structure.md b/experiments/pipeline/examples/project_structure.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/experiments/pipeline/examples/secrets_and_config.md b/experiments/pipeline/examples/secrets_and_config.md new file mode 100644 index 0000000000..032184530a --- /dev/null +++ b/experiments/pipeline/examples/secrets_and_config.md @@ -0,0 +1,155 @@ +## Example +How config values and secrets are handled should promote good behavior + +1. secret values should never be present in the pipeline code +2. config values can be provided, changed etc. when pipeline is deployed +3. 
still it must be easy and intuitive
+
+For the source extractor function below (reads a selected tab from google sheets) we can pass config values in the following ways:
+
+```python
+
+import dlt
+from dlt.destinations import bigquery
+
+
+@dlt.source
+def google_sheets(spreadsheet_id, tab_names, credentials, only_strings=False):
+    sheets = build('sheets', 'v4', credentials=Services.from_json(credentials))
+    tabs = []
+    for tab_name in tab_names:
+        data = sheets.get(spreadsheet_id, tab_name).execute().values()
+        tabs.append(dlt.resource(data, name=tab_name))
+    return tabs
+
+# WRONG: provide all values directly - wrong but possible. secret values should never be present in the code!
+google_sheets("23029402349032049", ["tab1", "tab2"], credentials={"private_key": ""}).run(destination=bigquery)
+
+# OPTION A: provide config values directly and secrets via automatic injection mechanism (see later)
+# `credentials` value will be provided by the `source` decorator
+# `spreadsheet_id` and `tab_names` take default values from the arguments below but may be overwritten by the decorator via config providers (see later)
+google_sheets("23029402349032049", ["tab1", "tab2"]).run(destination=bigquery)
+
+
+# OPTION B: all values are injected so there are no defaults and config values must be present in the providers
+google_sheets().run(destination=bigquery)
+
+
+# OPTION C: we use `dlt.secrets` and `dlt.config` to explicitly take those values from providers in the way we control (not recommended but straightforward)
+google_sheets(dlt.config["sheet_id"], dlt.config["tabs"], dlt.secrets["gcp_credentials"]).run(destination=bigquery)
+```
+
+## Injection mechanism
+By the magic of the @dlt.source decorator
+
+The signature of the function `google_sheets` also defines the structure of the configuration and secrets.
+
+When the `google_sheets` function is called, the decorator takes every input parameter and uses its value as the initial one.
+Then it looks into the `providers` to check if the value is overwritten there.
+It does the same for all arguments that were not in the call but are specified in the function signature.
+Then it calls the original function with the updated input arguments, thus passing config and secrets to it.
+
+## Providers
+When config or secret values are needed, `dlt` looks for them in providers. In case of `google_sheets()` it will always look for: `spreadsheet_id`, `tab_names`, `credentials` and `only_strings`.
+
+Providers form a hierarchy. At the top are environment variables, then `secrets.toml` and `config.toml` files. Providers like google, aws, azure vaults can be inserted after the environment provider.
+For example if `spreadsheet_id` is in the environment, dlt does not look into the other providers.
+
+The values passed in the code directly are the lowest in the provider hierarchy.
+
+## Namespaces
+Config and secret values can be grouped in namespaces. The easiest way to visualize it is via `toml` files.
+
+This is valid for OPTION A and OPTION B
+
+**secrets.toml**
+```toml
+client_email =
+private_key =
+project_id =
+```
+**config.toml**
+```toml
+spreadsheet_id="302940230490234903294"
+tab_names=["tab1", "tab2"]
+```
+
+**alternative secrets.toml**
+**secrets.toml**
+```toml
+[credentials]
+client_email =
+private_key =
+project_id =
+```
+
+where `credentials` is the name of the parameter from `google_sheets`.
This parameter is a namespace for the keys it contains and namespaces are *optional*
+
+For OPTION C the user uses their own custom keys to get credentials, so:
+**secrets.toml**
+```toml
+[gcp_credentials]
+client_email =
+private_key =
+project_id =
+```
+**config.toml**
+```toml
+sheet_id="302940230490234903294"
+tabs=["tab1", "tab2"]
+```
+
+But what about `bigquery` credentials? In the case above it will reuse the credentials from **secrets.toml** (in OPTION A and B) but what if we need different credentials?
+
+Dlt has a nice optional namespace structure to handle all conflicts. It becomes useful in advanced cases like above. The secrets and config files may look as follows (and they will work with OPTION A and B)
+
+**secrets.toml**
+```toml
+[source.credentials]
+client_email =
+private_key =
+project_id =
+
+[destination.credentials]
+client_email =
+private_key =
+project_id =
+
+```
+**config.toml**
+```toml
+[source]
+spreadsheet_id="302940230490234903294"
+tab_names=["tab1", "tab2"]
+```
+
+How do namespaces work in environment variables? They are a prefix for the key, so to get `spreadsheet_id` `dlt` will look for
+
+`SOURCE__SPREADSHEET_ID` first and `SPREADSHEET_ID` second
+
+## Interesting / Advanced stuff.
+
+The approach above makes configs and secrets explicit and autogenerates the required lookups. It lets me for example **generate deployments** and **code templates for pipeline scripts** automatically because I know what the config parameters are and I have total control over the user's code and final values via the decorator.
+
+There's more cool stuff
+
+Here's how a professional source function should look:
+
+```python
+
+
+@dlt.source
+def google_sheets(spreadsheet_id: str, tab_names: List[str], credentials: TCredentials, only_strings=False):
+    sheets = build('sheets', 'v4', credentials=Services.from_json(credentials))
+    tabs = []
+    for tab_name in tab_names:
+        data = sheets.get(spreadsheet_id, tab_name).execute().values()
+        tabs.append(dlt.resource(data, name=tab_name))
+    return tabs
+```
+
+Here I provide typing so I can type-check the injected values and no crap gets passed to the function.
+
+I also tell which argument is a secret via `TCredentials`, which lets me catch the case when the user puts secret values in `config.toml` or some other unsafe provider (and generate even better templates)
+
+We could go even deeper here (ie. configurations `spec` may be explicitly declared via python `dataclasses`, may be embedded in one another etc.
-> it comes useful when writing something really complicated) \ No newline at end of file From 6b65451387a91cdcc4a1409b2243d5aa28fab395 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sun, 23 Oct 2022 20:56:06 +0200 Subject: [PATCH 41/66] adds more code samples --- .../pipeline/examples/create_pipeline.md | 332 ++++++++++++++---- .../pipeline/examples/general_usage.md | 10 + .../examples/last_value_with_state.md | 26 ++ ...istep_pipelines_and_dependent_resources.md | 0 .../pipeline/examples/stream_resources.md | 4 + 5 files changed, 307 insertions(+), 65 deletions(-) create mode 100644 experiments/pipeline/examples/last_value_with_state.md create mode 100644 experiments/pipeline/examples/multistep_pipelines_and_dependent_resources.md create mode 100644 experiments/pipeline/examples/stream_resources.md diff --git a/experiments/pipeline/examples/create_pipeline.md b/experiments/pipeline/examples/create_pipeline.md index 03592228cb..1f486e04b3 100644 --- a/experiments/pipeline/examples/create_pipeline.md +++ b/experiments/pipeline/examples/create_pipeline.md @@ -1,5 +1,9 @@ ## Example for the simplest ad hoc pipeline without any structure +It is still possible to create "intuitive" pipeline without much knowledge except how to import dlt engine and how to import the destination. + +No decorators and secret files, configurations are necessary. We should probably not teach that but I want this kind of super basic and brute code to still work + ```python import requests @@ -8,123 +12,321 @@ from dlt.destinations import bigquery resp = requests.get( "https://taktile.com/api/v2/logs?from_log_id=1", - headers={"Authorization": dlt.secrets["taktile_api_key"]}) + headers={"Authorization": "98217398ahskj92982173"}) resp.raise_for_status() data = resp.json() -dlt.run(data["result"], name="logs", destination=bigquery) +# if destination or name are not provided, an exception will raise that explains +# 1. where and why to put the name of the table +# 2. how to import the destination and how to configure it with credentials in a proper way +# nevertheless the user decided to pass credentials directly +dlt.run(data["result"], name="logs", destination=bigquery(Service.credentials_from_file("service.json"))) ``` +## Source extractor function the preferred way +General guidelines: +1. the source extractor is a function decorated with `@dlt.source`. that function yields or returns a list of resources. **it should not access the data itself**. see the example below +2. resources are generator functions that always **yield** data (I think I will enforce that by raising exception). Access to external endpoints, databases etc. should happen from that generator function. Generator functions may be decorated with `@dlt.resource` to provide alternative names, write disposition etc. +3. resource generator functions can be OFC parametrized and resources may be created dynamically +4. the resource generator function may yield a single dict or list of dicts + +> my dilemma here is if I should allow to access data directly in the source function ie. to discover schema or get some configuration for the resources from some endpoint. it is very easy to avoid that but for the non-programmers it will not be intuitive. 
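+
+A minimal sketch of the layout described in the guidelines above (the endpoint URL and the `my_source`/`items` names are just placeholders, not a real API); the sections below walk through concrete variants of it:
+
+```python
+import requests
+import dlt
+
+@dlt.source
+def my_source(api_key):
+
+    # the resource is a generator - the API call happens here, not in the source body
+    @dlt.resource
+    def items():
+        resp = requests.get("https://example.com/api/items", headers={"Authorization": api_key})
+        resp.raise_for_status()
+        # yield a single dict or a list of dicts
+        yield resp.json()["result"]
+
+    # the source only returns the resources
+    return items
+```
+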
+ ## Example for endpoint returning only one resource: ```python import requests import dlt -# it will use function name `taktile_data` to name the source and schema +# the `dlt.source` tell the library that the decorated function is a source +# it will use function name `taktile_data` to name the source and the generated schema by default +# in general `@source` should **return** a list of resources or list of generators (function that yield data) +# @source may also **yield** resources or generators - if yielding is more convenient +# if @source returns or yields data - this will generate exception with a proper explanation. dlt user can always load the data directly without any decorators like in the previous example! @dlt.source def taktile_data(initial_log_id, taktile_api_key): - resp = requests.get( - "https://taktile.com/api/v2/logs?from_log_id=%i" % initial_log_id, - headers={"Authorization": taktile_api_key}) - resp.raise_for_status() - data = resp.json() - # yes you can return a list of values and it will work - return dlt.resource(data["result"], name="logs") + # the `dlt.resource` tells the `dlt.source` that the function defines a resource + # will use function name `logs` as resource/table name by default + # the function should **yield** the data items one by one or **yield** a list. + # here the decorator is optional: there are no parameters to `dlt.resource` + @dlt.resource + def logs(): + resp = requests.get( + "https://taktile.com/api/v2/logs?from_log_id=%i" % initial_log_id, + headers={"Authorization": taktile_api_key}) + resp.raise_for_status() + # option 1: yield the whole list + yield resp.json()["result"] + # or -> this is useful if you deal with a stream of data and for that you need an API that supports that, for example you could yield lists containing paginated results + for item in resp.json()["result"]: + yield item + + # as mentioned we return a resource or a list of resources + return logs + # this will also work + return logs() + +# now load the data taktile_data(1).run(destination=bigquery) # this below also works -# dlt.run(taktile_data(1), destination=bigquery) +# dlt.run(source=taktile_data(1), destination=bigquery) ``` -## With two resources: -also shows how to select just one resource to be loaded +**Remarks:** +1. the **@dlt.resource** let's you define the table schema hints: `name`, `write_disposition`, `parent`, `columns` +2. the **@dlt.source** let's you define global schema props: `name` (which is also source name), `schema` which is Schema object if explicit schema is provided `nesting` to set nesting level etc. (I do not have a signature now - I'm still working on it) +3. decorators can also be used as functions ie in case of dlt.resource and `lazy_function` (see one page below) ```python -import requests -import dlt +endpoints = ["songs", "playlist", "albums"] +# return list of resourced +return [dlt.resource(lazy_function(endpoint, name=endpoint) for endpoint in endpoints)] -@dlt.source +``` + +**What if we remove logs() function and get data in source body** + +Yeah definitely possible. Just replace `@source` with `@resource` decorator and remove the function + +```python +@dlt.resource(name="logs", write_disposition="append") def taktile_data(initial_log_id, taktile_api_key): + + # yes, this will also work but data will be obtained immediately when taktile_data() is called. 
resp = requests.get( "https://taktile.com/api/v2/logs?from_log_id=%i" % initial_log_id, headers={"Authorization": taktile_api_key}) resp.raise_for_status() - logs = resp.json()["results"] + for item in resp.json()["result"]: + yield item - resp = requests.get( - "https://taktile.com/api/v2/decisions%i" % initial_log_id, - headers={"Authorization": taktile_api_key}) - resp.raise_for_status() - decisions = resp.json()["results"] +# this will load the resource into default schema. see `general_usage.md) +dlt.run(source=taktile_data(1), destination=bigquery) - return dlt.resource(logs, name="logs"), dlt.resource(decisions, name="decisions", write_disposition="replace") +``` + +**The power of decorators** + +With decorators dlt can inspect and modify the code being decorated. +1. it knows what are the sources and resources without running them +2. it knows input arguments so it knows the config values and secret values (see `secrets_and_config`). with those we can generate deployments automatically +3. it can inject config and secret values automatically +4. it wraps the functions into objects that provide additional functionalities +- sources and resources are iterators so you can write +```python +items = list(source()) + +for item in source()["logs"]: + ... +``` +- you can select which resources to load with `source().select(*names)` +- you can add mappings and filters to resources + +## The power of yielding: The preferred way to write resources + +The Python function that yields is not a function but magical object that `dlt` can control: +1. it is not executed when you call it! the call just creates a generator (see below). in the example above `taktile_data(1)` will not execute the code inside, it will just return an object composed of function code and input parameters. dlt has control over the object and can execute the code later. this is called `lazy execution` +2. i can control when and how much of the code is executed. the function that yields typically looks like that + +```python +def lazy_function(endpoint_name): + # INIT - this will be executed only once when DLT wants! + get_configuration() + from_item = dlt.state.get("last_item", 0) + l = get_item_list_from_api(api_key, endpoint_name) + + # ITERATOR - this will be executed many times also when DLT wants more data! + for item in l: + yield requests.get(url, api_key, "%s?id=%s" % (endpoint_name, item["id"])).json() + # CLEANUP + # this will be executed only once after the last item was yielded! + dlt.state["last_item"] = item["id"] +``` + +3. dlt will execute this generator in extractor. the whole execution is atomic (including writing to state). if anything fails with exception the whole extract function fails. +4. the execution can be parallelized by using a decorator or a simple modifier function ie: +```python +for item in l: + yield deferred(requests.get(url, api_key, "%s?id=%s" % (endpoint_name, item["id"])).json()) +``` + +## Python data transformations + +```python +from dlt.secrets import anonymize + +def transform_user(user_data): + # anonymize creates nice deterministic hash for any hashable data type + user_data["user_id"] = anonymize(user_data["user_id"]) + user_data["user_email"] = anonymize(user_data["user_email"]) + return user_data + +# usage: can be applied in the source +@dlt.source +def hubspot(...): + ... + + @dlt.resource(write_disposition="replace") + def users(): + ... + users = requests.get(...) + # option 1: just map and yield from mapping + users = map(transform_user, users) + ... 
+ yield users, deals, customers + + # option 2: return resource with chained transformation + return users.map(transform_user) + +# option 3: user of the pipeline determines if s/he wants the anonymized data or not and does it in pipeline script. so the source may offer also transformations that are easily used +hubspot(...)["users"].map(transform_user) +hubspot.run(...) + +``` + +## Multiple resources and resource selection +The source extraction function may contain multiple resources. The resources can be defined as multiple resource functions or created dynamically ie. with parametrized generators. +The user of the pipeline can check what resources are available and select the resources to load. + + +**each resource has a a separate resource function** +```python +import requests +import dlt + +@dlt.source +def hubspot(...): + + @dlt.resource(write_disposition="replace") + def users(): + # calls to API happens here + ... + yield users + + @dlt.resource(write_disposition="append") + def transactions(): + ... + yield transactions + + # return a list of resources + return users, transactions # load all resources taktile_data(1).run(destination=bigquery) # load only decisions taktile_data(1).select("decisions").run(....) ``` -note: -`dlt.resource` takes all the parameters (ie. `write_disposition` or `columns` that let you define the table schema fully) -**alternative form which uses iterators** for very long responses that for example use HTTP chunked: +**resources are created dynamically** +Here we implement a single parametrized function that **yields** data and we call it repeatedly. Mind that the function body won't be executed immediately, only later when generator is consumed in extract stage. ```python -import requests -import dlt @dlt.source -def taktile_data(initial_log_id, taktile_api_key): +def spotify(): - # it will use the function name `logs` to name the resource/table - # yield the data which is really long jsonl stream - @dlt.resource - def logs(): - resp = requests.get( - "https://taktile.com/api/v2/logs?from_log_id=%i" % initial_log_id, - stream=True, - headers={"Authorization": taktile_api_key}) - resp.raise_for_status() - for line in resp.text(): - yield json.loads(line) + endpoints = ["songs", "playlists", "albums"] - # here we provide name and write_disposition directly - @dlt.resource(name="decisions", write_disposition="replace") - def decisions_reader(): - resp = requests.get( - "https://taktile.com/api/v2/decisions%i" % initial_log_id, - headers={"Authorization": taktile_api_key}) - resp.raise_for_status() - return resp.json()["results"] + def get_resource(endpoint): + # here we yield the whole response + yield requests.get(url + "/" + endpoint).json() + + # here we yield resources because this produces cleaner code + for endpoint in endpoints: + # calling get_resource creates generator, the actual code of the function will be executed in extractor + yield dlt.resource(get_resource(endpoint), name=endpoint) - return logs, decisions_reader ``` -## With pipeline state and incremental load +**resources are created dynamically from a single document** +Here we have a list of huge documents and we want to split it into several tables. We do not want to rely on `dlt` normalize stage to do it for us for some reason... +This also provides an example of why getting data in the source function (and not within the resource function) is discouraged. 
-from_log_id = dlt.state.get("from_log_id") or initial_log_id ```python -import requests -import dlt -# it will use function name `taktile_data` to name the source and schema @dlt.source -def taktile_data(initial_log_id, taktile_api_key): - from_log_id = dlt.state.get("from_log_id") or initial_log_id - resp = requests.get( - "https://taktile.com/api/v2/logs?from_log_id=%i" % initial_log_id, - headers={"Authorization": taktile_api_key}) - resp.raise_for_status() - data = resp.json() +def spotify(): - # write state before returning data + # get the data in source body and the simply return the resources + # this is discouraged because data access + list_of_huge_docs = requests.get(...) - # yes you can return a list of values and it will work - yield dlt.resource(data["result"], name="logs") + return dlt.resource(list_of_huge_docs["songs"], name="songs"), dlt.resource(list_of_huge_docs["playlists"], name="songs") +# the call to get the resource list happens outside the `dlt` pipeline, this means that if there's +# exception in `list_of_huge_docs = requests.get(...)` I cannot handle or log it (or send slack message) +# user must do it himself or the script will be simply killed. not so much problem during development +# but may be a problem after deployment. +spotify().run(...) +``` -taktile_data(1).run(destination=bigquery) +How to prevent that: +```python +@dlt.source +def spotify(): + + list_of_huge_docs = None + + def get_data(name): + # regarding intuitiveness and cleanliness of the code this is a hack/trickery IMO + # this will also have consequences if execution is parallelized + nonlocal list_of_huge_docs + docs = list_of_huge_docs or list_of_huge_docs = requests.get(...) + yield docs[name] + + return dlt.resource(get_data("songs"), name="songs"), dlt.resource(get_data("playlists"), name="songs") +``` + +The other way to prevent that (see also `multistep_pipelines_and_dependent_resources.md`) + +```python +@dlt.source +def spotify(): + + @dlt.resource + def get_huge_doc(name): + yield requests.get(...) + + # make songs and playlists to be dependent on get_huge_doc + @dlt.resource(depends_on=huge_doc) + def songs(huge_doc): + yield huge_doc["songs"] + + @dlt.resource(depends_on=huge_doc) + def playlists(huge_doc): + yield huge_doc["playlists"] + + # as you can see the get_huge_doc is not even returned, nevertheless it will be evaluated (only once) + # the huge doc will not be extracted and loaded + return songs, playlists +``` + +> I could also implement lazy evaluation of the @dlt.source function. this is a lot of trickery in the code but definitely possible. there are consequences though: if someone requests lists of resources or the initial schema in the pipeline script before `run` method the function body will be evaluated. It is really hard to make intuitive code to work properly. + +## Pipeline with multiple sources or with same source called twice + +Here our source is parametrized or we have several sources to be extracted. This is more or less Ty's twitter case. + +```python +@dlt.source +def mongo(from_id, to_id, credentials): + ... + +@dlt.source +def labels(): + ... 
+ + +# option 1: at some point I may parallelize execution of sources if called this way +dlt.run(source=[mongo(0, 100000), mongo(100001, 200000), labels()], destination=bigquery) +# option 2: be explicit (this has consequences: read the `run` method in `general_usage`) +p = dlt.pipeline(destination=bigquery) +p.extract(mongo(0, 100000)) +p.extract(mongo(100001, 200000)) +p.extract(labels()) +p.normalize() +p.load() diff --git a/experiments/pipeline/examples/general_usage.md b/experiments/pipeline/examples/general_usage.md index e69de29bb2..9cd82a167d 100644 --- a/experiments/pipeline/examples/general_usage.md +++ b/experiments/pipeline/examples/general_usage.md @@ -0,0 +1,10 @@ +## importing + +## running and `run` function + + +## ad hoc and configured pipelines + + + +## the default schema and the default data set \ No newline at end of file diff --git a/experiments/pipeline/examples/last_value_with_state.md b/experiments/pipeline/examples/last_value_with_state.md new file mode 100644 index 0000000000..b90d629f72 --- /dev/null +++ b/experiments/pipeline/examples/last_value_with_state.md @@ -0,0 +1,26 @@ + +## With pipeline state and incremental load + + +from_log_id = dlt.state.get("from_log_id") or initial_log_id +```python +import requests +import dlt + +# it will use function name `taktile_data` to name the source and schema +@dlt.source +def taktile_data(initial_log_id, taktile_api_key): + from_log_id = dlt.state.get("from_log_id") or initial_log_id + resp = requests.get( + "https://taktile.com/api/v2/logs?from_log_id=%i" % initial_log_id, + headers={"Authorization": taktile_api_key}) + resp.raise_for_status() + data = resp.json() + + # write state before returning data + + # yes you can return a list of values and it will work + yield dlt.resource(data["result"], name="logs") + + +taktile_data(1).run(destination=bigquery) \ No newline at end of file diff --git a/experiments/pipeline/examples/multistep_pipelines_and_dependent_resources.md b/experiments/pipeline/examples/multistep_pipelines_and_dependent_resources.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/experiments/pipeline/examples/stream_resources.md b/experiments/pipeline/examples/stream_resources.md new file mode 100644 index 0000000000..b46a2762ac --- /dev/null +++ b/experiments/pipeline/examples/stream_resources.md @@ -0,0 +1,4 @@ +Advanced: +There are stream resources that contain many data types ie. RASA Tracker so a single resource may map to many tables: +1. hints can also be functions and lambdas to create dynamic hints based on data items yielded +2. I will re-introduce the `with_table` modifier function from v1 - it is less efficient but more intuitive for the user \ No newline at end of file From b17e513e3b079a8bb4a67c467e5b848cb2e97b88 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sun, 23 Oct 2022 22:29:06 +0200 Subject: [PATCH 42/66] adds general usage samples --- .../pipeline/examples/general_usage.md | 89 ++++++++++++++++++- 1 file changed, 85 insertions(+), 4 deletions(-) diff --git a/experiments/pipeline/examples/general_usage.md b/experiments/pipeline/examples/general_usage.md index 9cd82a167d..1c8b7c2bf8 100644 --- a/experiments/pipeline/examples/general_usage.md +++ b/experiments/pipeline/examples/general_usage.md @@ -1,10 +1,91 @@ -## importing +## importing dlt +Basic `dlt` functionalities are imported with `import dlt`. Those functionalities are: +1. ability to run the pipeline (which means extract->normalize->load for particular source(s) and destination) with `dlt.run` +2. 
ability to configure the pipeline ie. provide alternative pipeline name, working directory, folders to import/export schema and various flags: `dlt.pipeline` +3. ability to decorate sources (`dlt.source`) and resources (`dlt.resources`) +4. ability to access secrets `dlt.secrets` and config values `dlt.config` -## running and `run` function +## importing destinations +We support a few built in destinations which may be imported as follows +```python +import dlt +from dlt.destinations import bigquery +from dlt.destinations import redshift +``` +The imported modules may be directly passed to `run` or `pipeline` method. The can be also called to provide credentials and other settings explicitly (discouraged) ie. `bigquery(Service.credentials_from_file("service.json"))` will work. -## ad hoc and configured pipelines +Destinations require `extras` to be installed, if that is not the case, an exception with user friendly message will tell how to do that. +## importing sources +We do not have any structure for the source repository so IDK. For `create pipeline` workflow the source is in the same script as `run` method. -## the default schema and the default data set \ No newline at end of file +## default and explicitly configured pipelines +When the `dlt` is imported a default pipeline is automatically created. That pipeline is configured via configuration providers (ie. `config.toml` or env variables). If no configuration is present, default values will be used. + +1. the name of the pipeline, the name of default schema (if not overridden by the source extractor function) and the default dataset (in destination) are set to **current module name** which in 99% of cases is the name of executing script +2. the working directory of the pipeline will be **OS temporary folder/pipeline name** +3. the logging level will be **INFO** +4. all other configuration options won't be set or will have default values. + +Pipeline can be explicitly created and configured via `dlt.pipeline()` that returns `Pipeline` object. All parameters are optional. If no parameter is provided then default pipeline is returned. Here's a list of options: +1. pipeline_name +2. working_dir +3. pipeline_secret - for deterministic hashing +4. destination - the imported destination module or module name (we accept strings so they can be configured) +5. import_schema_path +6. export_schema_path +7. full_refresh - if set to True all the pipeline working dir and all datasets will be dropped with each run +8. ...any other popular option... give me ideas. maybe `dataset_name`? + +> **Achtung** as per `secrets_and_config.md` the options passed in the code have **lower priority** than any config settings. Example: the pipeline name passed to `dlt.pipeline()` will be overwritten if `pipeline_name` is present in `config.toml` or `PIPELINE_NAME` is in config variables. + + +> It is possible to have several pipelines in a single script if many pipelines are configured via `dlt.pipeline()`. I think we do not want to train people on that so I will skipp the topic. + +## pipeline working directory + + +## the default schema and the default data set + +## running pipelines and `dlt.run` + `@source().run` functions +`dlt.run` + `@source().run` are shortcuts to `Pipeline::run` method on default or last configured (with `dlt.pipeline`) `Pipeline` object. + +The function takes the following parameters +1. 
source - required - the data to be loaded into destination: a `@dlt.source` or a list of those, a `@dlt.resource` or a list of those, an iterator/generator function or a list of those or iterable (ie. a list) holding something else that iterators. +2. destination +3. dataset +4. table_name, write_disposition etc. - only when data is: a single resource, an iterator (ie. generator function) or iterable (ie. list) +5. schema - a `Schema` instance to be used instead of schema provided by the source or the default schema + +The `run` function works as follows. +1. if there's any pending data to be normalized or loaded, this is done first. +2. only when successful more data is extracted +3. only when successful newly extracted data is normalized and loaded. + +extract / normalize / load are atomic. the `run` is as close to be atomic as possible. + +> `load` is atomic if SQL transformations ie in `dbt` and all the SQL queries take into account only committed `load_ids`. It is certainly possible - we did in for RASA but requires some work... Maybe we implement a fully atomic staging at some point in the loader. + + +## the `Pipeline` object + +## command line interface +I need concept for that. Commands that we need: + +1. `dlt init` to initialize new project and create project template for `create pipeline` use case. Should it also install `extras`? +2. `dlt deploy` to create deployment package (probably cron) + +I have two existing working commands +1. `dlt schema` to load and parse schema file and convert it into `json` or `yaml` +2. `dlt pipeline` to inspect a pipeline with a specified name/working folder + +We may also add: +1. `dlt run` or `dlt schedule` to run a pipeline in a script like cron would. + +## pipeline runtime setup + +1. logging +2. signals +3. unhandled exceptions \ No newline at end of file From f20871e289f57db410f28dec3b1594148fd53334 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sun, 23 Oct 2022 22:35:54 +0200 Subject: [PATCH 43/66] adds example to general usage --- experiments/pipeline/examples/general_usage.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/experiments/pipeline/examples/general_usage.md b/experiments/pipeline/examples/general_usage.md index 1c8b7c2bf8..fb8f0284e6 100644 --- a/experiments/pipeline/examples/general_usage.md +++ b/experiments/pipeline/examples/general_usage.md @@ -71,6 +71,22 @@ extract / normalize / load are atomic. the `run` is as close to be atomic as pos ## the `Pipeline` object +## Examples + +Loads data from `taktile_data` source function into bigquery. All the credentials amd credentials are taken from the config and secret providers. + +Script was run with `python taktile.py` + +```python +from my_taktile_source import taktile_data +from dlt.destinations import bigquery + +# I only want logs from the resources present in taktile_data +taktile_data.run(source=taktile_data(1).select("logs"), destination=bigquery) +``` + +pipeline name is explicitly configured. + ## command line interface I need concept for that. 
Commands that we need: From 0d5a894694e266688bea3a44d4936bb11b51dd34 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Mon, 24 Oct 2022 14:08:46 +0200 Subject: [PATCH 44/66] general usage doc completed + README --- experiments/pipeline/examples/README.md | 10 ++ .../pipeline/examples/create_pipeline.md | 7 + .../pipeline/examples/general_usage.md | 150 ++++++++++++++---- .../pipeline/examples/project_structure.md | 9 ++ .../pipeline/examples/secrets_and_config.md | 4 +- .../pipeline/examples/working_with_schemas.md | 0 6 files changed, 147 insertions(+), 33 deletions(-) create mode 100644 experiments/pipeline/examples/README.md create mode 100644 experiments/pipeline/examples/working_with_schemas.md diff --git a/experiments/pipeline/examples/README.md b/experiments/pipeline/examples/README.md new file mode 100644 index 0000000000..9ef139ce7b --- /dev/null +++ b/experiments/pipeline/examples/README.md @@ -0,0 +1,10 @@ +## Finished documents + +1. [general_usage.md](general_usage.md) +2. [create_pipeline.md](create_pipeline.md) +3. [secrets_and_config.md](secrets_and_config.md) + +## In progress + +1. [project_structure.md](project_structure.md) +2. [working_with_schemas.md](working_with_schemas.md) diff --git a/experiments/pipeline/examples/create_pipeline.md b/experiments/pipeline/examples/create_pipeline.md index 1f486e04b3..61af3b6e84 100644 --- a/experiments/pipeline/examples/create_pipeline.md +++ b/experiments/pipeline/examples/create_pipeline.md @@ -29,6 +29,7 @@ General guidelines: 2. resources are generator functions that always **yield** data (I think I will enforce that by raising exception). Access to external endpoints, databases etc. should happen from that generator function. Generator functions may be decorated with `@dlt.resource` to provide alternative names, write disposition etc. 3. resource generator functions can be OFC parametrized and resources may be created dynamically 4. the resource generator function may yield a single dict or list of dicts +5. like any other iterator, the @dlt.source and @dlt.resource **can be iterated and thus extracted and loaded only once**, see example below. > my dilemma here is if I should allow to access data directly in the source function ie. to discover schema or get some configuration for the resources from some endpoint. it is very easy to avoid that but for the non-programmers it will not be intuitive. @@ -72,6 +73,12 @@ def taktile_data(initial_log_id, taktile_api_key): taktile_data(1).run(destination=bigquery) # this below also works # dlt.run(source=taktile_data(1), destination=bigquery) + +# now to illustrate that each source can be loaded only once, if you run this below +data = taktile_data(1) +data.run(destination=bigquery) # works as expected +data.run(destination=bigquery) # generates empty load package as the data in the iterator is exhausted... maybe I should raise exception instead? + ``` **Remarks:** diff --git a/experiments/pipeline/examples/general_usage.md b/experiments/pipeline/examples/general_usage.md index fb8f0284e6..231960530e 100644 --- a/experiments/pipeline/examples/general_usage.md +++ b/experiments/pipeline/examples/general_usage.md @@ -13,49 +13,96 @@ from dlt.destinations import bigquery from dlt.destinations import redshift ``` -The imported modules may be directly passed to `run` or `pipeline` method. The can be also called to provide credentials and other settings explicitly (discouraged) ie. `bigquery(Service.credentials_from_file("service.json"))` will work. 
+The imported modules may be directly passed to `run` or `pipeline` method. They can be also called to provide credentials and other settings explicitly (discouraged) ie. `bigquery(Service.credentials_from_file("service.json"))` will bind the credentials to the module. Destinations require `extras` to be installed, if that is not the case, an exception with user friendly message will tell how to do that. ## importing sources -We do not have any structure for the source repository so IDK. For `create pipeline` workflow the source is in the same script as `run` method. +We do not have any structure for the source repository so IDK. For `create pipeline` workflow the source is in the same script as `run` method so the problem does not exist now (?). + +In principle, however, the importable sources are extractor functions so they are imported like any other function. ## default and explicitly configured pipelines -When the `dlt` is imported a default pipeline is automatically created. That pipeline is configured via configuration providers (ie. `config.toml` or env variables). If no configuration is present, default values will be used. +When the `dlt` is imported a default pipeline is automatically created. That pipeline is configured via configuration providers (ie. `config.toml` or env variables - see [secrets_and_config.md](secrets_and_config.md)). If no configuration is present, default values will be used. -1. the name of the pipeline, the name of default schema (if not overridden by the source extractor function) and the default dataset (in destination) are set to **current module name** which in 99% of cases is the name of executing script +1. the name of the pipeline, the name of default schema (if not overridden by the source extractor function) and the default dataset (in destination) are set to **current module name** which in 99% of cases is the name of executing python script 2. the working directory of the pipeline will be **OS temporary folder/pipeline name** 3. the logging level will be **INFO** 4. all other configuration options won't be set or will have default values. -Pipeline can be explicitly created and configured via `dlt.pipeline()` that returns `Pipeline` object. All parameters are optional. If no parameter is provided then default pipeline is returned. Here's a list of options: -1. pipeline_name -2. working_dir -3. pipeline_secret - for deterministic hashing -4. destination - the imported destination module or module name (we accept strings so they can be configured) -5. import_schema_path -6. export_schema_path +Pipeline can be explicitly created and configured via `dlt.pipeline()` that returns `Pipeline` object. All parameters are optional. If no parameter is provided then default pipeline is returned. Here's a list of options. All the options are configurable. +1. pipeline_name - default as above +2. working_dir - default as above +3. pipeline_secret - for deterministic hashing - default is random number +4. destination - the imported destination module or module name (we accept strings so they can be configured) - default is None +5. import_schema_path - default is None +6. export_schema_path - default is None 7. full_refresh - if set to True all the pipeline working dir and all datasets will be dropped with each run 8. ...any other popular option... give me ideas. maybe `dataset_name`? > **Achtung** as per `secrets_and_config.md` the options passed in the code have **lower priority** than any config settings. 
Example: the pipeline name passed to `dlt.pipeline()` will be overwritten if `pipeline_name` is present in `config.toml` or `PIPELINE_NAME` is in config variables.


> It is possible to have several pipelines in a single script if many pipelines are configured via `dlt.pipeline()`. I think we do not want to train people on that so I will skip the topic.

+## the default schema and the default data set name
+`dlt` follows the rules below when auto-generating schemas and naming the dataset to which the data will be loaded.
+
+**schemas are identified by schema names**
+
+**default schema** is the first schema that is provided or created within the pipeline. The first schema comes in the following ways:
+1. From the first extracted `@dlt.source` ie. if you `dlt.run(source=spotify(), ...)` and the `spotify` source has a schema with name `spotify` attached, it will be used as the default schema.
+2. it will be created from scratch if you extract a `@dlt.resource` or an iterator ie. list (example: `dlt.run(source=["a", "b", "c"], ...)`) and its name is the pipeline name or the generator function name if a generator is extracted. (I'm trying to be smart with automatic naming)
+3. it is explicitly passed with the `schema` parameter to `run` or `extract` methods - this forces all the sources regardless of the form to place their tables in that schema.
+
+The **default schema** comes into play when we extract data as in point (2) - without schema information. In that case the default schema is used to attach tables coming from that data.
+
+The pipeline works with multiple schemas. If you extract another source or provide a schema explicitly, that schema becomes part of the pipeline. Example
+```python
+
+p = dlt.pipeline(dataset="spotify_data_1")
+p.extract(source=spotify("me")) # gets schema "spotify" from spotify source, that schema becomes default schema
+p.extract(source=echonest("me").select("mel")) # get schema "echonest", all tables belonging to resource "mel" will be placed in that schema
+p.extract(source=[label1, label2, label3], name="labels") # will use default schema "spotify" for table "labels"
+```
+
+> learn more on how to work with schemas both via files and programmatically in [working_with_schemas.md](working_with_schemas.md)
+
+**dataset name**
+`dlt` will load data to a specified dataset in the destination. The dataset in case of bigquery is a native dataset, in case of redshift a native database schema. **One dataset can handle only one schema**.
+
+There is a default dataset name which is the same as the pipeline name. The dataset name can also be explicitly provided into `dlt.pipeline`, `dlt.run` and `Pipeline::load` methods.
+
+In case **there's only the default schema** in the pipeline, the data will be loaded into the dataset with that name. Example: `dlt.run(source=spotify("me"), dataset="spotify_data_1")` will load data into the dataset `spotify_data_1`
+
+In case **there are more schemas in the pipeline**, the data will be loaded into the dataset with name `{dataset_name}` for the default schema and `{dataset_name}_{schema_name}` for all the other schemas. For the example above:
+1. `spotify` tables and `labels` will load into `spotify_data_1`
+2. `mel` resource will load into `spotify_data_1_echonest`
+
+
+## pipeline working directory and state
+Another fundamental concept is the pipeline working directory.
This directory keeps the following information: +1. the extracted data and the load packages with jobs created by normalize +2. the current schemas with all the recent updates +3. the pipeline and source state files. -## pipeline working directory +**Pipeline working directory should be preserved between the runs - if possible** +If the working directory is not preserved: +1. the auto-evolved schema is reset to the initial one. the schema evolution is deterministic so it should not be a problem - just a time wasted to compare schemas with each run +2. if load package is not fully loaded and erased then the destination holds partially loaded and not committed `load_id` +3. the sources that need source state will not load incrementally. -## the default schema and the default data set +This is the situation right now. We could restore working directory from the destination (both schemas and state). Entirely doable (for some destinations) but can't be done right now. ## running pipelines and `dlt.run` + `@source().run` functions -`dlt.run` + `@source().run` are shortcuts to `Pipeline::run` method on default or last configured (with `dlt.pipeline`) `Pipeline` object. +`dlt.run` + `@source().run` are shortcuts to `Pipeline::run` method on default or last configured (with `dlt.pipeline`) `Pipeline` object. Please refer to [create_pipeline.md](create_pipeline.md) for examples. The function takes the following parameters 1. source - required - the data to be loaded into destination: a `@dlt.source` or a list of those, a `@dlt.resource` or a list of those, an iterator/generator function or a list of those or iterable (ie. a list) holding something else that iterators. 2. destination -3. dataset +3. dataset name 4. table_name, write_disposition etc. - only when data is: a single resource, an iterator (ie. generator function) or iterable (ie. list) 5. schema - a `Schema` instance to be used instead of schema provided by the source or the default schema @@ -66,14 +113,44 @@ The `run` function works as follows. extract / normalize / load are atomic. the `run` is as close to be atomic as possible. +the `run` and `load` return information on loaded packages: to which datasets, list of jobs etc. let me think what should be the content + > `load` is atomic if SQL transformations ie in `dbt` and all the SQL queries take into account only committed `load_ids`. It is certainly possible - we did in for RASA but requires some work... Maybe we implement a fully atomic staging at some point in the loader. ## the `Pipeline` object +There are many ways to create or get current pipeline object. +```python + +# create and get default pipeline +p1 = dlt.pipeline() +# create explicitly configured pipeline +p2 = dlt.pipeline(name="pipe", destination=bigquery) +# get recently created pipeline +assert dlt.pipeline() is p2 +# load data with recently created pipeline +assert dlt.run(source=taktile_data()) is p2 +assert taktile_data().run() is p2 + +``` + +The `Pipeline` object provides following functionalities: +1. `run`, `extract`, `normalize` and `load` methods +2. a `pipeline.schema` dictionary-like object to enumerate and get the schemas in pipeline +3. schema get with `pipeline.schemas[name]` is a live object: any modification to it is automatically applied to the pipeline with the next `run`, `load` etc. see [working_with_schemas.md](working_with_schemas.md) +4. it returns `sql_client` and `native_client` to get direct access to the destination (if destination supports SQL - currently all of them do) +5. 
it has several methods to inspect the pipeline state and I think those should be exposed via `dlt pipeline` CLI + +for example: +- list the extracted files if any +- list the load packages ready to load +- list the failed jobs in package +- show info on destination: what are the datasets, the current load_id, the current schema etc. + ## Examples -Loads data from `taktile_data` source function into bigquery. All the credentials amd credentials are taken from the config and secret providers. +Loads data from `taktile_data` source function into bigquery. All the credentials and configs are taken from the config and secret providers. Script was run with `python taktile.py` @@ -81,27 +158,38 @@ Script was run with `python taktile.py` from my_taktile_source import taktile_data from dlt.destinations import bigquery +# the `run` command below will create default pipeline and use it to load data # I only want logs from the resources present in taktile_data -taktile_data.run(source=taktile_data(1).select("logs"), destination=bigquery) +taktile_data.select("logs").run(destination=bigquery) + +# alternative +dlt.run(source=taktile_data.select("logs")) ``` -pipeline name is explicitly configured. +Explicitly configure schema before the use +```python +import dlt +from dlt.destinations import bigquery + +@dlt.source +def data(api_key): + ... -## command line interface -I need concept for that. Commands that we need: -1. `dlt init` to initialize new project and create project template for `create pipeline` use case. Should it also install `extras`? -2. `dlt deploy` to create deployment package (probably cron) +dlt.pipeline(name="pipe", destination=bigquery, dataset="extract_1") +# use dlt secrets directly to get api key +# no parameters needed to run - we configured destination and dataset already +data(dlt.secrets["api_key"]).run() +``` -I have two existing working commands -1. `dlt schema` to load and parse schema file and convert it into `json` or `yaml` -2. `dlt pipeline` to inspect a pipeline with a specified name/working folder +## command line interface +I need concept for that. see [project_structure.md](project_structure.md) -We may also add: -1. `dlt run` or `dlt schedule` to run a pipeline in a script like cron would. +## logging +I need your input for user friendly logging. What should we log? What is important to see? ## pipeline runtime setup -1. logging -2. signals -3. unhandled exceptions \ No newline at end of file +1. logging - creates logger with the name `dlt` which can be disabled the python way if someone does not like it. (contrary to `dbt` logger which is uncontrollable mess) +2. signals - signals required to gracefully stop pipeline with CTRL-C, in docker, kubernetes, cron are handled. signals are not handled if `dlt` runs as part of `streamlit` app or a notebook. +3. unhandled exceptions - we do not catch unhandled exceptions... but we may do that if run in standalone script. \ No newline at end of file diff --git a/experiments/pipeline/examples/project_structure.md b/experiments/pipeline/examples/project_structure.md index e69de29bb2..a42a0c5d31 100644 --- a/experiments/pipeline/examples/project_structure.md +++ b/experiments/pipeline/examples/project_structure.md @@ -0,0 +1,9 @@ +1. `dlt init` to initialize new project and create project template for `create pipeline` use case. Should it also install `extras`? +2. `dlt deploy` to create deployment package (probably cron) + +I have two existing working commands +1. 
`dlt schema` to load and parse schema file and convert it into `json` or `yaml` +2. `dlt pipeline` to inspect a pipeline with a specified name/working folder + +We may also add: +1. `dlt run` or `dlt schedule` to run a pipeline in a script like cron would. diff --git a/experiments/pipeline/examples/secrets_and_config.md b/experiments/pipeline/examples/secrets_and_config.md index 032184530a..f3ec694f36 100644 --- a/experiments/pipeline/examples/secrets_and_config.md +++ b/experiments/pipeline/examples/secrets_and_config.md @@ -148,8 +148,8 @@ def google_sheets(spreadsheet_id: str, tab_names: List[str], credentials: TCrede return tabs ``` -Here I provide typing so I can type check injected values so no crap gets passed to the function. +Here I provide typing so I can type check injected values so no junk data gets passed to the function. -I also tell which argument is secret via `TCredentials` that let's me control for the case when user is putting secret values in `config.toml` or some other unsafe provider (and generate even better templates) +> I also tell which argument is secret via `TCredentials` that let's me control for the case when user is putting secret values in `config.toml` or some other unsafe provider (and generate even better templates) We could go even deeper here (ie. configurations `spec` may be explicitly declared via python `dataclasses`, may be embedded in one another etc. -> it comes useful when writing something really complicated) \ No newline at end of file diff --git a/experiments/pipeline/examples/working_with_schemas.md b/experiments/pipeline/examples/working_with_schemas.md new file mode 100644 index 0000000000..e69de29bb2 From 20a9186db2b37098b01af4df985c838df5661878 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 25 Oct 2022 12:48:32 +0200 Subject: [PATCH 45/66] adds project structure example --- experiments/pipeline/examples/README.md | 8 ++-- .../pipeline/examples/project_structure.md | 40 +++++++++++++++---- .../project_structure/.dlt/config.toml | 2 + .../examples/project_structure/.gitignore | 1 + .../examples/project_structure/README.md | 3 ++ .../examples/project_structure/__init__.py | 0 .../examples/project_structure/pipeline.py | 37 +++++++++++++++++ .../project_structure/requirements.txt | 2 + 8 files changed, 82 insertions(+), 11 deletions(-) create mode 100644 experiments/pipeline/examples/project_structure/.dlt/config.toml create mode 100644 experiments/pipeline/examples/project_structure/.gitignore create mode 100644 experiments/pipeline/examples/project_structure/README.md create mode 100644 experiments/pipeline/examples/project_structure/__init__.py create mode 100644 experiments/pipeline/examples/project_structure/pipeline.py create mode 100644 experiments/pipeline/examples/project_structure/requirements.txt diff --git a/experiments/pipeline/examples/README.md b/experiments/pipeline/examples/README.md index 9ef139ce7b..965831d625 100644 --- a/experiments/pipeline/examples/README.md +++ b/experiments/pipeline/examples/README.md @@ -1,10 +1,10 @@ ## Finished documents 1. [general_usage.md](general_usage.md) -2. [create_pipeline.md](create_pipeline.md) -3. [secrets_and_config.md](secrets_and_config.md) +2. [project_structure.md](project_structure.md) & `dlt init` CLI +3. [create_pipeline.md](create_pipeline.md) +4. [secrets_and_config.md](secrets_and_config.md) ## In progress -1. [project_structure.md](project_structure.md) -2. [working_with_schemas.md](working_with_schemas.md) +1. 
[working_with_schemas.md](working_with_schemas.md)

diff --git a/experiments/pipeline/examples/project_structure.md b/experiments/pipeline/examples/project_structure.md
index a42a0c5d31..0dd18097f1 100644
--- a/experiments/pipeline/examples/project_structure.md
+++ b/experiments/pipeline/examples/project_structure.md
@@ -1,9 +1,35 @@
-1. `dlt init` to initialize new project and create project template for `create pipeline` use case. Should it also install `extras`?
-2. `dlt deploy` to create deployment package (probably cron)
+## Project structure for a create pipeline workflow
 
-I have two existing working commands
-1. `dlt schema` to load and parse schema file and convert it into `json` or `yaml`
-2. `dlt pipeline` to inspect a pipeline with a specified name/working folder
+Look into [project_structure](project_structure). It is a clone of the template repository that we should have in our github. The files in the repository are parametrized with the parameters of the `dlt init` command.
 
-We may also add:
-1. `dlt run` or `dlt schedule` to run a pipeline in a script like cron would.
+1. it contains a `.dlt` folder with `config.toml` and `secrets.toml`.
+2. we prefill those files with values corresponding to the destination
+3. the requirements contain `python-dlt` and `requests` in `requirements.txt`
+4. `.gitignore` for `secrets.toml` and `.env` (python virtual environment)
+5. the pipeline script file `pipeline.py` containing a template for a new source
+6. `README.md` file with whatever content we need
+
+
+## dlt init
+
+The prerequisites to run the command are to
+1. create a virtual environment
+2. install `python-dlt` without extras
+
+> Question: any better ideas? I do not see anything simpler to go around.
+
+Proposed interface for the command:
+`dlt init `
+where `destination` must be one of our supported destination names: `bigquery` or `redshift`, and `source` is an alphanumeric string.
+
+The command should be executed in an empty directory without `.git` or any other files. It will clone a template and create the project structure as above. The files in the project will be customized:
+
+1. `secrets.toml` will be prefilled with the required credentials and secret values
+2. `config.toml` will contain `pipeline_name`
+3. the `pipeline.py` (1) will import the right destination (2) the source name will be changed to `_data` (3) the dataset name will be changed to `` etc.
+4. `requirements.txt` will contain the proper dlt extras and the requests library
+
+> Questions:
+> 1. should we generate a working pipeline as a template (ie. with an existing API) or a piece of code with instructions how to change it?
+> 2. which features should we show in the template? parametrized source? providing api key and simple authentication? many resources? parametrized resources? configure export and import of schema yaml files? etc?
+> 3. should we `pip install` the required extras and requests when `dlt init` is run?
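To make the template parametrization above concrete, here is a minimal sketch of what `dlt init` could do after cloning the template. The helper name and the `{source_name}` / `{destination_name}` placeholder convention are illustrative assumptions, not something decided in this design:

```python
from pathlib import Path


def customize_template(project_dir: str, source_name: str, destination_name: str) -> None:
    # hypothetical placeholder substitution over the cloned template files
    for relative_path in ["pipeline.py", ".dlt/config.toml", "requirements.txt"]:
        file_path = Path(project_dir) / relative_path
        text = file_path.read_text(encoding="utf-8")
        text = text.replace("{source_name}", source_name).replace("{destination_name}", destination_name)
        file_path.write_text(text, encoding="utf-8")


# `dlt init twitter bigquery` executed in the current (empty) directory
customize_template(".", "twitter", "bigquery")
```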
diff --git a/experiments/pipeline/examples/project_structure/.dlt/config.toml b/experiments/pipeline/examples/project_structure/.dlt/config.toml new file mode 100644 index 0000000000..fd197222ac --- /dev/null +++ b/experiments/pipeline/examples/project_structure/.dlt/config.toml @@ -0,0 +1,2 @@ +pipeline_name="twitter" +# export_ diff --git a/experiments/pipeline/examples/project_structure/.gitignore b/experiments/pipeline/examples/project_structure/.gitignore new file mode 100644 index 0000000000..b3b3ed2cb7 --- /dev/null +++ b/experiments/pipeline/examples/project_structure/.gitignore @@ -0,0 +1 @@ +secrets.toml \ No newline at end of file diff --git a/experiments/pipeline/examples/project_structure/README.md b/experiments/pipeline/examples/project_structure/README.md new file mode 100644 index 0000000000..7e17869bc1 --- /dev/null +++ b/experiments/pipeline/examples/project_structure/README.md @@ -0,0 +1,3 @@ +# How to customize and deploy this pipeline? + +Maybe the training syllabus goes here? \ No newline at end of file diff --git a/experiments/pipeline/examples/project_structure/__init__.py b/experiments/pipeline/examples/project_structure/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/experiments/pipeline/examples/project_structure/pipeline.py b/experiments/pipeline/examples/project_structure/pipeline.py new file mode 100644 index 0000000000..3b2bb61c9b --- /dev/null +++ b/experiments/pipeline/examples/project_structure/pipeline.py @@ -0,0 +1,37 @@ +import requests +import dlt +from dlt.destinations import bigquery + + +# explain `dlt.source` a little here and last_id and api_key parameters +@dlt.source +def twitter_data(last_id, api_key): + # example of Bearer Authentication + # create authorization headers + headers = { + "Authorization": f"Bearer {api_key}" + } + + # explain the `dlt.resource` and the default table naming, write disposition etc. + @dlt.resource + def example_data(): + # make a call to the endpoint with request library + resp = requests.get("https://example.com/data?last_id=%i" % last_id, headers=headers) + resp.raise_for_status() + # yield the data from the resource + data = resp.json() + # you may process the data here + # example transformation? 
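+        # for example, one could keep only the items newer than `last_id` (assuming each item carries an `id` field):
+        # data["items"] = [item for item in data["items"] if item["id"] > last_id]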
+ # return resource to be loaded into `example_data` table + # explain that data["items"] contains a list of items + yield data["items"] + + # return all the resources to be loaded + return example_data + +# configure the pipeline +dlt.pipeline(destination=bigquery, dataset="twitter") +# explain that api_key will be automatically loaded from secrets.toml or environment variable below +load_info = twitter_data(0).run() +# pretty print the information on data that was loaded +print(load_info) diff --git a/experiments/pipeline/examples/project_structure/requirements.txt b/experiments/pipeline/examples/project_structure/requirements.txt new file mode 100644 index 0000000000..1ecee01bab --- /dev/null +++ b/experiments/pipeline/examples/project_structure/requirements.txt @@ -0,0 +1,2 @@ +python-dlt[bigquery]==0.1.0rc14 +requests \ No newline at end of file From 150da6d61e3e406558b647613c1b4bfcf0428a52 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 25 Oct 2022 14:19:26 +0200 Subject: [PATCH 46/66] adds working with schemas doc --- experiments/pipeline/examples/README.md | 3 +- .../pipeline/examples/working_with_schemas.md | 115 ++++++++++++++++++ 2 files changed, 117 insertions(+), 1 deletion(-) diff --git a/experiments/pipeline/examples/README.md b/experiments/pipeline/examples/README.md index 965831d625..9588996df6 100644 --- a/experiments/pipeline/examples/README.md +++ b/experiments/pipeline/examples/README.md @@ -4,7 +4,8 @@ 2. [project_structure.md](project_structure.md) & `dlt init` CLI 3. [create_pipeline.md](create_pipeline.md) 4. [secrets_and_config.md](secrets_and_config.md) +5. [working_with_schemas.md](working_with_schemas.md) ## In progress -1. [working_with_schemas.md](working_with_schemas.md) +I'll be writing advanced stuff later (ie. state and multi step pipelines etc.) diff --git a/experiments/pipeline/examples/working_with_schemas.md b/experiments/pipeline/examples/working_with_schemas.md index e69de29bb2..29d93cc6b1 100644 --- a/experiments/pipeline/examples/working_with_schemas.md +++ b/experiments/pipeline/examples/working_with_schemas.md @@ -0,0 +1,115 @@ +## General approach to define schemas + +## Schema components + +### Schema content hash and version +Each schema file contains content based hash `version_hash` that is used to +1. detect manual changes to schema (ie. user edits content) +2. detect if the destination database schema is synchronized with the file schema + +Each time the schema is saved, the version hash is updated. + +Each schema contains also numeric version which increases automatically whenever schema is updated and saved. This version is mostly for informative purposes and currently the user can easily reset it by wiping out the pipeline working dir (until we restore the current schema from the destination) + +> Currently the destination schema sync procedure uses the numeric version. I'm changing it to hash based versioning. + +### Normalizer and naming convention +The data normalizer and the naming convention are part of the schema configuration. In principle the source can set own naming convention or json unpacking mechanism. Or user can overwrite those in `config.toml` + +#### Relational normalizer config +Yes those are part of the normalizer module and can be plugged in. +1. column propagation from parent -> child +2. nesting level +3. 
parent -> child table linking type
+### Global hints, preferred data type hints, data type autodetectors
+
+## Working with schema files
+`dlt` automates working with schema files by setting up schema import and export folders. The settings are available via config providers (ie. `config.toml`) or via the `dlt.pipeline(import_schema_path, export_schema_path)` arguments. Example:
+```python
+dlt.pipeline(import_schema_path="schemas/import", export_schema_path="schemas/export")
+```
+will create the following folder structure in the project root folder
+```
+schemas
+    |---import/
+    |---export/
+```
+
+This will expose the pipeline schemas to the user in `yml` format.
+
+1. When a new pipeline is created and a source function is extracted for the first time, a new schema is added to the pipeline. This schema is created out of the global hints and the resource hints present in the source extractor function. It **does not depend on the data - that happens in the normalize stage**.
+2. Every such new schema will be saved to the `import` folder (if not existing there already) and used as the initial version for all future pipeline runs.
+3. Once a schema is present in the `import` folder, **it is writable by the user only**.
+4. Any changes to the schemas in that folder are detected and propagated to the pipeline automatically on the next run (in fact any call to the `Pipeline` object does that sync). It means that after a user update, the schema in the `import` folder resets all the automatic updates from the data.
+5. Otherwise **the schema evolves automatically in the normalize stage** and each update is saved in the `export` folder. The export folder is **writable by dlt only** and provides the actual view of the schema.
+6. The `export` and `import` folders may be the same. In that case the evolved schema is automatically "accepted" as the initial one.
+
+
+## Working with schema in code
+A `dlt` user can "check out" any pipeline schema for modification in the code.
+
+> I do not have any cool API to work with the table, columns and other hints in the code - the schema is a typed dictionary and currently it is the only way.
+
+`dlt` will "commit" all the schema changes with any call to the `run`, `extract`, `normalize` or `load` methods.
+
+Examples:
+
+```python
+# extract some data to the "table" resource using the default schema
+p = dlt.pipeline(destination=redshift)
+p.extract([1,2,3,4], name="table")
+# get the live schema
+schema = p.default_schema
+# we want the list data to be text, not integer
+schema.tables["table"]["columns"]["value"] = schema_utils.new_column("value", "text")
+# `run` will apply the schema changes and run the normalizer and loader for the already extracted data
+p.run()
+```
+
+> The `normalize` stage creates standalone load packages, each containing data and a schema with a particular version. Those packages are of course not impacted by the "live" schema changes.
+
+## Attaching schemas to sources
+The general approach when creating a new pipeline is to set up a few global schema settings and then let the table and column schemas be generated from the resource hints and the data itself.
+
+> I do not have any cool "schema builder" api yet to set the global settings.
+ +Example: + +```python + +schema: Schema = None + +def setup_schema(nesting_level, hash_names_convention=False): + nonlocal schema + + # get default normalizer config + normalizer_conf = dlt.schema.normalizer_config() + # set hash names convention which produces short names without clashes but very ugly + if short_names_convention: + normalizer_conf["names"] = dlt.common.normalizers.names.hash_names + # remove date detector and add type detector that forces all fields to strings + normalizer_conf["detections"].remove("iso_timestamp") + normalizer_conf["detections"].insert(0, "all_text") + + # apply normalizer conf + schema = Schema("createx", normalizer_conf) + # set nesting level, yeah it's ugly + schema._normalizers_config["json"].setdefault("config", {})["max_nesting"] = nesting_level + +# apply schema to the source +@dlt.source(schema=schema) +def createx(): + ... + +``` + +Two other behaviors are supported +1. bare `dlt.source` will create empty schema with the source name +2. `dlt.source(name=...)` will first try to load `{name}_schema.yml` from the same folder the source python file exist. If not found, new empty schema will be created + + +## Open issues + +1. Name clashes. +2. Lack of lineage. +3. Names, types and hints interpretation depend on destination From 8282f304c0fb5ff6b3a7f1c61eceaaffcfdc5882 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 26 Oct 2022 19:52:40 +0200 Subject: [PATCH 47/66] first implementation of source and resource decorators, adds pipeline context --- experiments/pipeline/__init__.py | 46 ++- experiments/pipeline/decorators.py | 145 +++++++++ .../pipeline/examples/general_usage.md | 2 + experiments/pipeline/exceptions.py | 18 +- experiments/pipeline/pipeline.py | 278 +++++++++++++----- experiments/pipeline/typing.py | 2 +- 6 files changed, 398 insertions(+), 93 deletions(-) create mode 100644 experiments/pipeline/decorators.py diff --git a/experiments/pipeline/__init__.py b/experiments/pipeline/__init__.py index a334c8c90b..a9bc16ec79 100644 --- a/experiments/pipeline/__init__.py +++ b/experiments/pipeline/__init__.py @@ -1,13 +1,15 @@ import tempfile from typing import Union -from importlib import import_module from dlt.common.typing import TSecretValue, Any from dlt.common.configuration import with_config -from dlt.load.client_base import DestinationReference +from dlt.common.configuration.container import Container +from dlt.common.destination import DestinationReference, resolve_destination_reference +from dlt.common.pipeline import PipelineContext from experiments.pipeline.configuration import PipelineConfiguration from experiments.pipeline.pipeline import Pipeline +from experiments.pipeline.decorators import source, resource # @overload @@ -25,17 +27,41 @@ @with_config(spec=PipelineConfiguration, auto_namespace=True) -def configure(pipeline_name: str = None, working_dir: str = None, pipeline_secret: TSecretValue = None, destination: Union[None, str, DestinationReference] = None, **kwargs: Any) -> Pipeline: - print(locals()) +def pipeline( + pipeline_name: str = None, + working_dir: str = None, + pipeline_secret: TSecretValue = None, + destination: Union[None, str, DestinationReference] = None, + dataset_name: str = None, + import_schema_path: str = None, + export_schema_path: str = None, + always_drop_pipeline: bool = False, + **kwargs: Any +) -> Pipeline: + # call without parameters returns current pipeline + if not locals(): + context = Container()[PipelineContext] + # if pipeline instance is already active then return it, otherwise 
create a new one + if context.is_activated(): + return context.pipeline() + print(kwargs["_last_dlt_config"].pipeline_name) # if working_dir not provided use temp folder if not working_dir: working_dir = tempfile.gettempdir() - # if destination is a str, get destination reference by dynamically importing module from known location - if isinstance(destination, str): - destination = import_module(f"dlt.load.{destination}") + destination = resolve_destination_reference(destination) + # create new pipeline instance + p = Pipeline(pipeline_name, working_dir, pipeline_secret, destination, dataset_name, import_schema_path, export_schema_path, always_drop_pipeline, kwargs["runtime"]) + # set it as current pipeline + Container()[PipelineContext].activate(p) + + return p + +# setup default pipeline in the container +print("CONTEXT") +Container()[PipelineContext] = PipelineContext(pipeline) - return Pipeline(pipeline_name, working_dir, pipeline_secret, destination, kwargs["runtime"]) -def run() -> Pipeline: - return configure().extract() \ No newline at end of file +def run(source: Any, destination: Union[None, str, DestinationReference] = None) -> Pipeline: + destination = resolve_destination_reference(destination) + return pipeline().run(source=source, destination=destination) diff --git a/experiments/pipeline/decorators.py b/experiments/pipeline/decorators.py new file mode 100644 index 0000000000..38305d6582 --- /dev/null +++ b/experiments/pipeline/decorators.py @@ -0,0 +1,145 @@ +import inspect +from types import ModuleType +from makefun import wraps +from typing import Any, Dict, NamedTuple, Optional, Type + +from dlt.common.configuration import with_config, get_fun_spec +from dlt.common.configuration.specs import BaseConfiguration +from dlt.common.exceptions import ArgumentsOverloadException +from dlt.common.schema.schema import Schema +from dlt.common.schema.typing import TTableSchemaColumns, TWriteDisposition +from dlt.common.source import TTableHintTemplate, TFunHintTemplate +from dlt.common.typing import AnyFun, TFun +from dlt.common.utils import is_inner_function +from dlt.extract.sources import DltResource, DltSource + + +class SourceInfo(NamedTuple): + SPEC: Type[BaseConfiguration] + f: AnyFun + module: ModuleType + + +_SOURCES: Dict[str, SourceInfo] = {} + + +def source(func: Optional[AnyFun] = None, /, name: str = None, schema: Schema = None, spec: Type[BaseConfiguration] = None): + + if name and schema: + raise ArgumentsOverloadException("Source name cannot be set if schema is present") + + def decorator(f: TFun) -> TFun: + nonlocal schema, name + + # extract name + if schema: + name = schema.name + else: + name = name or f.__name__ + # create or load default schema + # TODO: we need a convention to load ie. 
load the schema from file with name_schema.yaml + schema = Schema(name) + + # wrap source extraction function in configuration with namespace + conf_f = with_config(f, spec=spec, namespaces=("source", name)) + + @wraps(conf_f, func_name=name) + def _wrap(*args: Any, **kwargs: Any) -> DltSource: + rv = conf_f(*args, **kwargs) + # if generator, consume it immediately + if inspect.isgenerator(rv): + rv = list(rv) + + def check_rv_type(rv: Any) -> None: + pass + + # check if return type is list or tuple + if isinstance(rv, (list, tuple)): + # check all returned elements + for v in rv: + check_rv_type(v) + else: + check_rv_type(rv) + + # convert to source + return DltSource.from_data(schema, rv) + + # get spec for wrapped function + SPEC = get_fun_spec(conf_f) + # store the source information + _SOURCES[_wrap.__qualname__] = SourceInfo(SPEC, _wrap, inspect.getmodule(f)) + + return _wrap + + if func is None: + # we're called with parens. + return decorator + + if not callable(func): + raise ValueError("First parameter to the source must be callable ie. by using it as function decorator") + + # we're called as @source without parens. + return decorator(func) + + +def resource( + data: Optional[Any] = None, + /, + name: TTableHintTemplate[str] = None, + table_name_fun: TFunHintTemplate[str] = None, + write_disposition: TTableHintTemplate[TWriteDisposition] = None, + columns: TTableHintTemplate[TTableSchemaColumns] = None, + selected: bool = True, + depends_on: DltResource = None, + spec: Type[BaseConfiguration] = None): + + def make_resource(name, _data: Any) -> DltResource: + table_template = DltResource.new_table_template(table_name_fun or name, write_disposition=write_disposition, columns=columns) + return DltResource.from_data(_data, name, table_template, selected, depends_on) + + + def decorator(f: TFun) -> TFun: + resource_name = name or f.__name__ + + # if f is not a generator (does not yield) raise Exception + if not inspect.isgeneratorfunction(inspect.unwrap(f)): + raise ResourceFunNotGenerator() + + # do not inject config values for inner functions, we assume that they are part of the source + SPEC: Type[BaseConfiguration] = None + if is_inner_function(f): + conf_f = f + else: + print("USE SPEC -> GLOBAL") + # wrap source extraction function in configuration with namespace + conf_f = with_config(f, spec=spec, namespaces=("resource", resource_name)) + # get spec for wrapped function + SPEC = get_fun_spec(conf_f) + + @wraps(conf_f, func_name=resource_name) + def _wrap(*args: Any, **kwargs: Any) -> DltResource: + return make_resource(resource_name, f(*args, **kwargs)) + + # store the standalone resource information + if SPEC: + _SOURCES[_wrap.__qualname__] = SourceInfo(SPEC, _wrap, inspect.getmodule(f)) + + return _wrap + + + # if data is callable or none use decorator + if data is None: + # we're called with parens. 
+ return decorator + + if callable(data): + return decorator(data) + else: + return make_resource(name, data) + + +def _get_source_for_inner_function(f: AnyFun) -> Optional[SourceInfo]: + # find source function + parts = f.__qualname__.split(".") + parent_fun = ".".join(parts[:-2]) + return _SOURCES.get(parent_fun) diff --git a/experiments/pipeline/examples/general_usage.md b/experiments/pipeline/examples/general_usage.md index 231960530e..c7b214375b 100644 --- a/experiments/pipeline/examples/general_usage.md +++ b/experiments/pipeline/examples/general_usage.md @@ -82,6 +82,8 @@ In case **there are more schemas in the pipeline**, the data will be loaded into ## pipeline working directory and state +the working directory of the pipeline will be **OS temporary folder/pipeline name** + Another fundamental concept is the pipeline working directory. This directory keeps the following information: 1. the extracted data and the load packages with jobs created by normalize 2. the current schemas with all the recent updates diff --git a/experiments/pipeline/exceptions.py b/experiments/pipeline/exceptions.py index af1df29f53..7cb58e59fa 100644 --- a/experiments/pipeline/exceptions.py +++ b/experiments/pipeline/exceptions.py @@ -1,5 +1,5 @@ from typing import Any, Sequence -from dlt.common.exceptions import DltException +from dlt.common.exceptions import DltException, ArgumentsOverloadException from dlt.common.telemetry import TRunMetrics from experiments.pipeline.typing import TPipelineStep @@ -43,14 +43,14 @@ def __init__(self, config_elem: str, step: TPipelineStep, help: str = None) -> N super().__init__(msg) -class PipelineConfiguredException(PipelineException): - def __init__(self, f_name: str) -> None: - super().__init__(f"{f_name} cannot be called on already configured or restored pipeline.") +# class PipelineConfiguredException(PipelineException): +# def __init__(self, f_name: str) -> None: +# super().__init__(f"{f_name} cannot be called on already configured or restored pipeline.") -class InvalidPipelineContextException(PipelineException): - def __init__(self) -> None: - super().__init__("There may be just one active pipeline in single python process. To activate current pipeline call `activate` method") +# class InvalidPipelineContextException(PipelineException): +# def __init__(self) -> None: +# super().__init__("There may be just one active pipeline in single python process. 
To activate current pipeline call `activate` method") class CannotRestorePipelineException(PipelineException): @@ -79,3 +79,7 @@ def __init__(self, step: TPipelineStep, exception: BaseException, run_metrics: T self.exception = exception self.run_metrics = run_metrics super().__init__(f"Pipeline execution failed at stage {step} with exception:\n\n{type(exception)}\n{exception}") + + +# class CannotApplyHintsToManyResources(ArgumentsOverloadException): +# pass diff --git a/experiments/pipeline/pipeline.py b/experiments/pipeline/pipeline.py index e97a8666ec..5ffe84dba0 100644 --- a/experiments/pipeline/pipeline.py +++ b/experiments/pipeline/pipeline.py @@ -2,28 +2,31 @@ from contextlib import contextmanager from copy import deepcopy from functools import wraps +from collections.abc import Sequence as C_Sequence from typing import Any, Callable, ClassVar, List, Iterable, Iterator, Generator, Mapping, NewType, Optional, Sequence, Tuple, Type, TypedDict, Union, get_type_hints, overload from dlt.common import json, logger, signals from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_namespace_context import ConfigNamespacesContext from dlt.common.runners.runnable import Runnable -from dlt.common.sources import DLT_METADATA_FIELD, TResolvableDataItem, with_table_name +from dlt.common.schema.typing import TColumnSchema, TWriteDisposition +from dlt.common.source import DLT_METADATA_FIELD, TResolvableDataItem, with_table_name from dlt.common.typing import DictStrAny, StrAny, TFun, TSecretValue, TAny from dlt.common.runners import pool_runner as runner, TRunMetrics, initialize_runner from dlt.common.storages import LiveSchemaStorage, NormalizeStorage from dlt.common.configuration import inject_namespace -from dlt.common.configuration.specs import RunConfiguration, NormalizeVolumeConfiguration, SchemaVolumeConfiguration, LoadVolumeConfiguration, PoolRunnerConfiguration, DestinationCapabilitiesContext -from dlt.common.schema.schema import Schema +from dlt.common.configuration.specs import RunConfiguration, NormalizeVolumeConfiguration, SchemaVolumeConfiguration, LoadVolumeConfiguration, PoolRunnerConfiguration +from dlt.common.destination import DestinationCapabilitiesContext, DestinationReference, JobClientBase, DestinationClientConfiguration, DestinationClientDwhConfiguration +from dlt.common.schema import Schema, utils as schema_utils from dlt.common.storages.file_storage import FileStorage from dlt.common.utils import is_interactive from dlt.extract.extract import ExtractorStorage, extract from dlt.normalize import Normalize -from dlt.load.client_base import DestinationReference, JobClientBase, SqlClientBase -from dlt.load.configuration import DestinationClientConfiguration, DestinationClientDwhConfiguration, LoaderConfiguration +from dlt.load.client_base import SqlClientBase +from dlt.load.configuration import LoaderConfiguration from dlt.load import Load from dlt.normalize.configuration import NormalizeConfiguration @@ -39,14 +42,26 @@ class Pipeline: STATE_PROPS: ClassVar[List[str]] = list(get_type_hints(TPipelineState).keys()) pipeline_name: str - dataset_name: str default_schema_name: str + always_drop_pipeline: bool working_dir: str - def __init__(self, pipeline_name: str, working_dir: str, pipeline_secret: TSecretValue, destination: DestinationReference, runtime: RunConfiguration): + def __init__( + self, + pipeline_name: str, + working_dir: str, + pipeline_secret: TSecretValue, + destination: DestinationReference, + dataset_name: str, + 
import_schema_path: str, + export_schema_path: str, + always_drop_pipeline: bool, + runtime: RunConfiguration + ) -> None: self.pipeline_secret = pipeline_secret self.runtime_config = runtime self.destination = destination + self.dataset_name = dataset_name self.root_folder: str = None self._container = Container() @@ -59,7 +74,7 @@ def __init__(self, pipeline_name: str, working_dir: str, pipeline_secret: TSecre self._load_storage_config: LoadVolumeConfiguration = None initialize_runner(self.runtime_config) - self._configure(pipeline_name, working_dir) + self._configure(pipeline_name, working_dir, import_schema_path, export_schema_path, always_drop_pipeline) def with_state_sync(f: TFun) -> TFun: @@ -98,36 +113,27 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: return decorator - @with_state_sync - def _configure(self, pipeline_name: str, working_dir: str) -> None: - self.pipeline_name = pipeline_name - self.working_dir = working_dir - - # compute the folder that keeps all of the pipeline state - FileStorage.validate_file_name_component(self.pipeline_name) - self.root_folder = os.path.join(self.working_dir, self.pipeline_name) - # create default configs - # self._pool_config = PoolRunnerConfiguration(is_single_run=True, exit_on_exception=True) - self._schema_storage_config = SchemaVolumeConfiguration(schema_volume_path = os.path.join(self.root_folder, "schemas")) - self._normalize_storage_config = NormalizeVolumeConfiguration(normalize_volume_path=os.path.join(self.root_folder, "normalize")) - self._load_storage_config = LoadVolumeConfiguration(load_volume_path=os.path.join(self.root_folder, "load"),) - - # create pipeline working dir - self._pipeline_storage = FileStorage(self.root_folder, makedirs=False) - - # restore pipeline if folder exists and contains state - if self._pipeline_storage.has_file(Pipeline.STATE_FILE): - self._restore_pipeline() - else: - self._create_pipeline() - - # create schema storage - self._schema_storage = LiveSchemaStorage(self._schema_storage_config, makedirs=True) - def drop(self) -> "Pipeline": """Deletes existing pipeline state, schemas and drops datasets at the destination if present""" - pass + # drop the data for all known schemas + for schema in self._schema_storage: + with self._get_destination_client(schema) as client: + client.initialize_storage(wipe_data=True) + # reset the pipeline working dir + self._create_pipeline() + # clone the pipeline + return Pipeline( + self.pipeline_name, + self.working_dir, + self.pipeline_secret, + self.destination, + self.dataset_name, + self._schema_storage.config.import_schema_path, + self._schema_storage.config.export_schema_path, + self.always_drop_pipeline, + self.runtime_config + ) # @overload @@ -158,22 +164,23 @@ def drop(self) -> "Pipeline": @with_config_namespace(("extract",)) def extract( self, - data: Union[DltSource, DltResource, Iterator[TResolvableDataItem], Iterable[TResolvableDataItem]], - table_name = None, - write_disposition = None, - parent = None, - columns = None, + data: Any, + table_name: str, + parent_table_name: str = None, + write_disposition: TWriteDisposition = None, + columns: Sequence[TColumnSchema] = None, schema: Schema = None, *, max_parallel_items: int = 100, workers: int = 5 ) -> None: - def only_data_args(with_schema: bool) -> None: - if not table_name or not write_disposition or not parent or not columns: - raise InvalidExtractArguments(with_schema) - if not with_schema and not schema: - raise InvalidExtractArguments(with_schema) + + # def has_hint_args() -> bool: 
+ # return table_name or parent_table_name or write_disposition or schema + + def apply_hint_args(resource: DltResource) -> None: + resource.apply_hints(table_name, parent_table_name, write_disposition, columns) def choose_schema() -> Schema: if schema: @@ -182,30 +189,50 @@ def choose_schema() -> Schema: return self.default_schema return Schema(self.pipeline_name) - source: DltSource = None - - if isinstance(data, DltSource): - # already a source - only_data_args(with_schema=False) - source = data - elif isinstance(data, DltResource): - # package resource in source - only_data_args(with_schema=True) - source = DltSource(choose_schema(), [data]) - else: - table_schema: TTableSchemaTemplate = { - "name": table_name, - "parent": parent, - "write_disposition": write_disposition, - "columns": columns - } - # convert iterable to resource - data = DltResource.from_data(data, name=table_name, table_schema_template=table_schema) + # a list of sources or a list of resources may be passed as data + sources: List[DltSource] = [] + + def item_to_source(data_item: Any) -> DltSource: + if isinstance(data_item, DltSource): + # if schema is explicit then override source schema + if schema: + data_item.schema = schema + # try to apply hints to resources + resources = data_item.resources + for r in resources: + apply_hint_args(r) + return data_item + + if isinstance(data_item, DltResource): + # apply hints + apply_hint_args(data_item) + # package resource in source + return DltSource(choose_schema(), [data_item]) + + # iterator/iterable/generator + # create resource first without table template + resource = DltResource.from_data(data_item, name=table_name) + # apply hints + apply_hint_args(resource) # wrap resource in source - source = DltSource(choose_schema(), [data]) + return DltSource(choose_schema(), [resource]) + + if isinstance(data, C_Sequence) and len(data) > 0: + # if first element is source or resource + if isinstance(data[0], DltResource): + sources.append(item_to_source(DltSource(choose_schema(), data))) + elif isinstance(data[0], DltSource): + for s in data: + sources.append(item_to_source(s)) + else: + sources.append(item_to_source(data)) + else: + sources.append(item_to_source(data)) try: - self._extract_source(source, max_parallel_items, workers) + # extract all sources + for s in sources: + self._extract_source(s, max_parallel_items, workers) except Exception as exc: raise PipelineStepFailed("extract", self.last_run_exception, runner.LAST_RUN_METRICS) from exc @@ -215,6 +242,9 @@ def choose_schema() -> Schema: def normalize(self, workers: int = 1, dry_run: bool = False) -> None: if is_interactive() and workers > 1: raise NotImplementedError("Do not use normalize workers in interactive mode ie. 
in notebook") + # check if any schema is present, if not then no data was extracted + if not self.default_schema_name: + return # get destination capabilities destination_caps = self._get_destination_capabilities() @@ -244,7 +274,7 @@ def load( credentials: Any = None, # raise_on_failed_jobs = False, # raise_on_incompatible_schema = False, - # always_drop_dataset = False, + always_wipe_storage = False, *, workers: int = 20 ) -> None: @@ -265,11 +295,36 @@ def load( is_single_run=True, exit_on_exception=True, workers=workers, + always_wipe_storage=always_wipe_storage or self.always_drop_pipeline, load_storage_config=self._load_storage_config ) load = Load(self.destination, is_storage_owner=False, config=load_config, initial_client_config=client_initial_config) self._run_step_in_pool("load", load, load.config) + @with_config_namespace(("run",)) + def run( + self, + source: Any = None, + destination: DestinationReference = None, + dataset_name: str = None, + table_name: str = None, + write_disposition: TWriteDisposition = None, + columns: Sequence[TColumnSchema] = None, + schema: Schema = None + ) -> None: + # set destination and default dataset if provided + self.destination = destination or self.destination + self.dataset_name = dataset_name or self.dataset_name + # normalize and load pending data + self.normalize() + self.load(destination, dataset_name) + + # extract from the source + if source: + self.extract(source, table_name, write_disposition, None, columns, schema) + self.normalize() + self.load(destination, dataset_name) + @property def schemas(self) -> Mapping[str, Schema]: return self._schema_storage @@ -282,8 +337,43 @@ def default_schema(self) -> Schema: def last_run_exception(self) -> BaseException: return runner.LAST_RUN_EXCEPTION + @with_state_sync + def _configure(self, pipeline_name: str, working_dir: str, import_schema_path: str, export_schema_path: str, always_drop_pipeline: bool) -> None: + self.pipeline_name = pipeline_name + self.working_dir = working_dir + self.always_drop_pipeline = always_drop_pipeline + + # compute the folder that keeps all of the pipeline state + FileStorage.validate_file_name_component(self.pipeline_name) + self.root_folder = os.path.join(self.working_dir, self.pipeline_name) + # create default configs + # self._pool_config = PoolRunnerConfiguration(is_single_run=True, exit_on_exception=True) + self._schema_storage_config = SchemaVolumeConfiguration( + schema_volume_path=os.path.join(self.root_folder, "schemas"), + import_schema_path=import_schema_path, + export_schema_path=export_schema_path + ) + self._normalize_storage_config = NormalizeVolumeConfiguration(normalize_volume_path=os.path.join(self.root_folder, "normalize")) + self._load_storage_config = LoadVolumeConfiguration(load_volume_path=os.path.join(self.root_folder, "load"),) + + # create pipeline working dir + self._pipeline_storage = FileStorage(self.root_folder, makedirs=False) + + # restore pipeline if folder exists and contains state + if self._pipeline_storage.has_file(Pipeline.STATE_FILE) and not always_drop_pipeline: + self._restore_pipeline() + else: + # this will erase the existing working folder + self._create_pipeline() + + # create schema storage + self._schema_storage = LiveSchemaStorage(self._schema_storage_config, makedirs=True) + def _create_pipeline(self) -> None: - self._pipeline_storage.create_folder(".", exists_ok=True) + # kill everything inside the working folder + if self._pipeline_storage.has_folder(""): + self._pipeline_storage.delete_folder("", 
recursively=True) + self._pipeline_storage.create_folder("", exists_ok=False) def _restore_pipeline(self) -> None: self._restore_state() @@ -294,16 +384,25 @@ def _restore_state(self) -> None: self._state.update(restored_state) def _extract_source(self, source: DltSource, max_parallel_items: int, workers: int) -> None: - storage = ExtractorStorage(self._normalize_storage_config) + # discover the schema from source + source_schema = source.discover_schema() + # iterate over all items in the pipeline and update the schema if dynamic table hints were present + storage = ExtractorStorage(self._normalize_storage_config) for _, partials in extract(source, storage, max_parallel_items=max_parallel_items, workers=workers).items(): for partial in partials: - source.schema.update_schema(source.schema.normalize_table_identifiers(partial)) + source_schema.update_schema(source_schema.normalize_table_identifiers(partial)) + + # if source schema does not exist in the pipeline + if source_schema.name not in self._schema_storage: + # possibly initialize the import schema if it is a new schema + self._schema_storage.initialize_import_if_new(source_schema) + # save schema into the pipeline + self._schema_storage.save_schema(source_schema) + # and set as default if this is first schema in pipeline + if not self.default_schema_name: + self.default_schema_name = source_schema.name - # save schema and set as default if this is first one - self._schema_storage.save_schema(source.schema) - if not self.default_schema_name: - self.default_schema_name = source.schema.name def _run_step_in_pool(self, step: TPipelineStep, runnable: Runnable[Any], config: PoolRunnerConfiguration) -> int: try: @@ -318,7 +417,33 @@ def _run_step_in_pool(self, step: TPipelineStep, runnable: Runnable[Any], config finally: signals.raise_if_signalled() - def _get_destination_client_initial_config(self, credentials: Any) -> DestinationClientConfiguration: + def _run_f_in_pool(self, run_f: Callable[..., Any], config: PoolRunnerConfiguration) -> int: + # internal runners should work in single mode + self._loader_instance.config.is_single_run = True + self._loader_instance.config.exit_on_exception = True + self._normalize_instance.config.is_single_run = True + self._normalize_instance.config.exit_on_exception = True + + def _run(_: Any) -> TRunMetrics: + rv = run_f() + if isinstance(rv, TRunMetrics): + return rv + if isinstance(rv, int): + pending = rv + else: + pending = 1 + return TRunMetrics(False, False, int(pending)) + + # run the fun + ec = runner.run_pool(config, _run) + # ec > 0 - signalled + # -1 - runner was not able to start + + if runner.LAST_RUN_METRICS is not None and runner.LAST_RUN_METRICS.has_failed: + raise self.last_run_exception + return ec + + def _get_destination_client_initial_config(self, credentials: Any = None) -> DestinationClientConfiguration: if not self.destination: raise PipelineConfigMissing( "destination", @@ -336,6 +461,9 @@ def _get_destination_client_initial_config(self, credentials: Any) -> Destinatio def _get_destination_client(self, schema: Schema, initial_config: DestinationClientConfiguration = None) -> JobClientBase: try: + # config is not provided then get it with injected credentials + if not initial_config: + initial_config = self._get_destination_client_initial_config() return self.destination.client(schema, initial_config) except ImportError: client_spec = self.destination.spec() diff --git a/experiments/pipeline/typing.py b/experiments/pipeline/typing.py index 9fbf917265..adb9f2f6e2 100644 --- 
a/experiments/pipeline/typing.py +++ b/experiments/pipeline/typing.py @@ -7,7 +7,7 @@ class TPipelineState(TypedDict): pipeline_name: str dataset_name: str default_schema_name: Optional[str] - # destination_name: Optional[str] + # destination: Optional[str] # TSourceState = NewType("TSourceState", DictStrAny) From 8ca3b9bf8178997c102389da3ccc6a7f577471e0 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 26 Oct 2022 19:53:13 +0200 Subject: [PATCH 48/66] extracts destination and pipeline common code --- dlt/__init__.py | 2 +- dlt/common/configuration/__init__.py | 2 +- dlt/common/configuration/container.py | 3 + dlt/common/configuration/exceptions.py | 2 +- dlt/common/configuration/inject.py | 9 ++ dlt/common/configuration/specs/__init__.py | 1 - .../specs/destination_capabilities_context.py | 25 ---- .../specs/schema_volume_configuration.py | 2 +- dlt/common/data_writers/buffered.py | 2 +- dlt/common/data_writers/writers.py | 2 +- dlt/common/destination.py | 138 ++++++++++++++++++ dlt/common/exceptions.py | 8 +- dlt/common/normalizers/json/relational.py | 2 +- dlt/common/pipeline.py | 45 ++++++ dlt/common/schema/__init__.py | 2 +- dlt/common/schema/schema.py | 11 +- dlt/common/schema/typing.py | 6 +- dlt/common/schema/utils.py | 35 ++++- dlt/common/{sources.py => source.py} | 4 +- dlt/common/storages/data_item_storage.py | 2 +- dlt/common/storages/live_schema_storage.py | 10 +- dlt/common/typing.py | 4 +- dlt/common/utils.py | 8 +- dlt/dbt_runner/runner.py | 2 +- dlt/extract/extract.py | 2 +- dlt/extract/pipe.py | 9 +- dlt/extract/sources.py | 135 +++++++++++++---- dlt/load/bigquery/__init__.py | 4 +- dlt/load/bigquery/bigquery.py | 7 +- dlt/load/bigquery/configuration.py | 3 +- dlt/load/client_base.py | 94 +----------- dlt/load/client_base_impl.py | 6 +- dlt/load/configuration.py | 13 +- dlt/load/dummy/__init__.py | 4 +- dlt/load/dummy/configuration.py | 4 +- dlt/load/dummy/dummy.py | 4 +- dlt/load/exceptions.py | 3 +- dlt/load/load.py | 7 +- dlt/load/redshift/__init__.py | 4 +- dlt/load/redshift/configuration.py | 3 +- dlt/load/redshift/redshift.py | 11 +- dlt/load/typing.py | 1 - dlt/normalize/configuration.py | 3 +- dlt/normalize/normalize.py | 2 +- dlt/pipeline/pipeline.py | 2 +- examples/sources/rasa_tracker_store.py | 2 +- examples/sources/singer_tap.py | 2 +- tests/common/configuration/test_inject.py | 6 +- tests/common/schema/test_inference.py | 2 +- tests/common/storages/test_schema_storage.py | 4 + tests/conftest.py | 2 + tests/load/test_client.py | 2 +- tests/load/test_dummy_client.py | 3 +- tests/load/utils.py | 8 +- tests/normalize/test_normalize.py | 4 +- 55 files changed, 430 insertions(+), 253 deletions(-) delete mode 100644 dlt/common/configuration/specs/destination_capabilities_context.py create mode 100644 dlt/common/destination.py create mode 100644 dlt/common/pipeline.py rename dlt/common/{sources.py => source.py} (94%) diff --git a/dlt/__init__.py b/dlt/__init__.py index a68927d6ca..3dc1f76bc6 100644 --- a/dlt/__init__.py +++ b/dlt/__init__.py @@ -1 +1 @@ -__version__ = "0.1.0" \ No newline at end of file +__version__ = "0.1.0" diff --git a/dlt/common/configuration/__init__.py b/dlt/common/configuration/__init__.py index 34e590ac49..ccc176c807 100644 --- a/dlt/common/configuration/__init__.py +++ b/dlt/common/configuration/__init__.py @@ -1,6 +1,6 @@ from .specs.base_configuration import configspec, is_valid_hint # noqa: F401 from .resolve import resolve_configuration, inject_namespace # noqa: F401 -from .inject import with_config, last_config +from .inject 
import with_config, last_config, get_fun_spec from .exceptions import ( # noqa: F401 ConfigEntryMissingException, ConfigEnvValueCannotBeCoercedException, ConfigIntegrityException, ConfigFileNotFoundException) diff --git a/dlt/common/configuration/container.py b/dlt/common/configuration/container.py index 1f0d180c45..fff80d79ed 100644 --- a/dlt/common/configuration/container.py +++ b/dlt/common/configuration/container.py @@ -37,6 +37,9 @@ def __getitem__(self, spec: Type[TConfiguration]) -> TConfiguration: return item # type: ignore + def __setitem__(self, spec: Type[TConfiguration], value: TConfiguration) -> None: + self.contexts[spec] = value + def __contains__(self, spec: Type[TConfiguration]) -> bool: return spec in self.contexts diff --git a/dlt/common/configuration/exceptions.py b/dlt/common/configuration/exceptions.py index b0c5967440..ee70ec9c62 100644 --- a/dlt/common/configuration/exceptions.py +++ b/dlt/common/configuration/exceptions.py @@ -102,7 +102,7 @@ def __init__(self, spec: Type[Any], existing_config: Any, expected_config: Any) super().__init__(f"When restoring context {spec.__name__}, instance {expected_config} was expected, instead instance {existing_config} was found.") -class ContextDefaultCannotBeCreated(ConfigurationException): +class ContextDefaultCannotBeCreated(ConfigurationException, KeyError): def __init__(self, spec: Type[Any]) -> None: self.spec = spec super().__init__(f"Container cannot create the default value of context {spec.__name__}.") diff --git a/dlt/common/configuration/inject.py b/dlt/common/configuration/inject.py index e7c0ed20c8..52eef51b3d 100644 --- a/dlt/common/configuration/inject.py +++ b/dlt/common/configuration/inject.py @@ -14,6 +14,12 @@ _SLEEPING_CAT_SPLIT = re.compile("[^.^_]+") _LAST_DLT_CONFIG = "_last_dlt_config" TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration) +# keep a registry of all the decorated functions +_FUNC_SPECS: Dict[str, Type[BaseConfiguration]] = {} + + +def get_fun_spec(f: AnyFun) -> Type[BaseConfiguration]: + return _FUNC_SPECS.get(id(f)) @overload @@ -97,6 +103,9 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: # call the function with resolved config return f(*bound_args.args, **bound_args.kwargs) + # register the spec for a wrapped function + _FUNC_SPECS[id(_wrap)] = SPEC + return _wrap # type: ignore # See if we're being called as @with_config or @with_config(). 
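The `_FUNC_SPECS` registry added above makes the SPEC resolved for a wrapped function retrievable later. A minimal sketch of the intended usage, mirroring the call in `decorators.py` (the function and namespace names below are illustrative):

```python
from dlt.common.configuration import with_config, get_fun_spec


def tab_data(spreadsheet_id: str = None, api_key: str = None):
    ...


# wrap the function so that, at call time, its arguments may be injected from config providers
conf_f = with_config(tab_data, namespaces=("source", "tab_data"))
# the SPEC registered during wrapping can now be looked up, ie. to introspect
# which config and secret values the function expects
SPEC = get_fun_spec(conf_f)
```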
diff --git a/dlt/common/configuration/specs/__init__.py b/dlt/common/configuration/specs/__init__.py index c5efbc46a8..bd48a01909 100644 --- a/dlt/common/configuration/specs/__init__.py +++ b/dlt/common/configuration/specs/__init__.py @@ -6,5 +6,4 @@ from .pool_runner_configuration import PoolRunnerConfiguration, TPoolType # noqa: F401 from .gcp_client_credentials import GcpClientCredentials # noqa: F401 from .postgres_credentials import PostgresCredentials # noqa: F401 -from .destination_capabilities_context import DestinationCapabilitiesContext # noqa: F401 from .config_namespace_context import ConfigNamespacesContext # noqa: F401 \ No newline at end of file diff --git a/dlt/common/configuration/specs/destination_capabilities_context.py b/dlt/common/configuration/specs/destination_capabilities_context.py deleted file mode 100644 index a5832f383b..0000000000 --- a/dlt/common/configuration/specs/destination_capabilities_context.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import List, ClassVar, Literal - -from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext, configspec - -# known loader file formats -# jsonl - new line separated json documents -# puae-jsonl - internal extract -> normalize format bases on jsonl -# insert_values - insert SQL statements -TLoaderFileFormat = Literal["jsonl", "puae-jsonl", "insert_values"] - - -@configspec(init=True) -class DestinationCapabilitiesContext(ContainerInjectableContext): - """Injectable destination capabilities required for many Pipeline stages ie. normalize""" - preferred_loader_file_format: TLoaderFileFormat - supported_loader_file_formats: List[TLoaderFileFormat] - max_identifier_length: int - max_column_length: int - max_query_length: int - is_max_query_length_in_bytes: bool - max_text_data_type_length: int - is_max_text_data_type_length_in_bytes: bool - - # do not allow to create default value, destination caps must be always explicitly inserted into container - can_create_default: ClassVar[bool] = False diff --git a/dlt/common/configuration/specs/schema_volume_configuration.py b/dlt/common/configuration/specs/schema_volume_configuration.py index 324b2e418f..a5b70d3068 100644 --- a/dlt/common/configuration/specs/schema_volume_configuration.py +++ b/dlt/common/configuration/specs/schema_volume_configuration.py @@ -14,5 +14,5 @@ class SchemaVolumeConfiguration(BaseConfiguration): external_schema_format_remove_defaults: bool = True # remove default values when exporting schema if TYPE_CHECKING: - def __init__(self, schema_volume_path: str = None) -> None: + def __init__(self, schema_volume_path: str = None, import_schema_path: str = None, export_schema_path: str = None) -> None: ... 
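The `SchemaVolumeConfiguration` stub now also accepts the schema import/export folders. A short sketch of constructing it the way `Pipeline._configure` does in this patch (the paths are illustrative):

```python
import os

from dlt.common.configuration.specs import SchemaVolumeConfiguration

pipeline_root = os.path.join("/tmp", "my_pipeline")  # illustrative working dir
storage_config = SchemaVolumeConfiguration(
    schema_volume_path=os.path.join(pipeline_root, "schemas"),
    import_schema_path="schemas/import",
    export_schema_path="schemas/export"
)
```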
diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py index 7b5e3f1054..bbb6a380b2 100644 --- a/dlt/common/data_writers/buffered.py +++ b/dlt/common/data_writers/buffered.py @@ -2,7 +2,7 @@ from dlt.common.utils import uniq_id from dlt.common.typing import TDataItem -from dlt.common.sources import TDirectDataItem +from dlt.common.source import TDirectDataItem from dlt.common.data_writers import TLoaderFileFormat from dlt.common.data_writers.exceptions import BufferedDataWriterClosed, InvalidFileNameTemplateException from dlt.common.data_writers.writers import DataWriter diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 716b4bdf99..399e76c506 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -9,7 +9,7 @@ from dlt.common.json import json_typed_dumps from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.data_writers.escape import escape_redshift_identifier, escape_redshift_literal -from dlt.common.configuration.specs.destination_capabilities_context import TLoaderFileFormat +from dlt.common.destination import TLoaderFileFormat @dataclass diff --git a/dlt/common/destination.py b/dlt/common/destination.py new file mode 100644 index 0000000000..67ad218e5d --- /dev/null +++ b/dlt/common/destination.py @@ -0,0 +1,138 @@ +from abc import ABC, abstractmethod +from importlib import import_module +from types import TracebackType +from typing import ClassVar, List, Optional, Literal, Type, Protocol, Union + +from dlt.common.schema import Schema +from dlt.common.schema.typing import TTableSchema +from dlt.common.typing import ConfigValue +from dlt.common.configuration import configspec +from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration, ContainerInjectableContext + + +# known loader file formats +# jsonl - new line separated json documents +# puae-jsonl - internal extract -> normalize format bases on jsonl +# insert_values - insert SQL statements +TLoaderFileFormat = Literal["jsonl", "puae-jsonl", "insert_values"] + + +@configspec(init=True) +class DestinationCapabilitiesContext(ContainerInjectableContext): + """Injectable destination capabilities required for many Pipeline stages ie. 
normalize""" + preferred_loader_file_format: TLoaderFileFormat + supported_loader_file_formats: List[TLoaderFileFormat] + max_identifier_length: int + max_column_length: int + max_query_length: int + is_max_query_length_in_bytes: bool + max_text_data_type_length: int + is_max_text_data_type_length_in_bytes: bool + + # do not allow to create default value, destination caps must be always explicitly inserted into container + can_create_default: ClassVar[bool] = False + + +@configspec(init=True) +class DestinationClientConfiguration(BaseConfiguration): + destination_name: str = None # which destination to load data to + credentials: Optional[CredentialsConfiguration] + + +@configspec(init=True) +class DestinationClientDwhConfiguration(DestinationClientConfiguration): + dataset_name: str = None # dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix + default_schema_name: Optional[str] = None # name of default schema to be used to name effective dataset to load data to + + +TLoadJobStatus = Literal["running", "failed", "retry", "completed"] + + +class LoadJob: + """Represents a job that loads a single file + + Each job starts in "running" state and ends in one of terminal states: "retry", "failed" or "completed". + Each job is uniquely identified by a file name. The file is guaranteed to exist in "running" state. In terminal state, the file may not be present. + In "running" state, the loader component periodically gets the state via `status()` method. When terminal state is reached, load job is discarded and not called again. + `exception` method is called to get error information in "failed" and "retry" states. + + The `__init__` method is responsible to put the Job in "running" state. It may raise `LoadClientTerminalException` and `LoadClientTransientException` tp + immediately transition job into "failed" or "retry" state respectively. + """ + def __init__(self, file_name: str) -> None: + """ + File name is also a job id (or job id is deterministically derived) so it must be globally unique + """ + self._file_name = file_name + + @abstractmethod + def status(self) -> TLoadJobStatus: + pass + + @abstractmethod + def file_name(self) -> str: + pass + + @abstractmethod + def exception(self) -> str: + pass + + +class JobClientBase(ABC): + def __init__(self, schema: Schema, config: DestinationClientConfiguration) -> None: + self.schema = schema + self.config = config + + @abstractmethod + def initialize_storage(self) -> None: + pass + + @abstractmethod + def update_storage_schema(self) -> None: + pass + + @abstractmethod + def start_file_load(self, table: TTableSchema, file_path: str) -> LoadJob: + pass + + @abstractmethod + def restore_file_load(self, file_path: str) -> LoadJob: + pass + + @abstractmethod + def complete_load(self, load_id: str) -> None: + pass + + @abstractmethod + def __enter__(self) -> "JobClientBase": + pass + + @abstractmethod + def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None: + pass + + @classmethod + @abstractmethod + def capabilities(cls) -> DestinationCapabilitiesContext: + pass + + +class DestinationReference(Protocol): + def capabilities(self) -> DestinationCapabilitiesContext: + ... + + def client(self, schema: Schema, initial_config: DestinationClientConfiguration = ConfigValue) -> "JobClientBase": + ... + + def spec(self) -> Type[DestinationClientConfiguration]: + ... 
+ + +def resolve_destination_reference(destination: Union[None, str, DestinationReference]) -> DestinationReference: + if destination is None: + return None + + if isinstance(destination, str): + # TODO: figure out if this is full module path name or name of one of the known destinations + # if destination is a str, get destination reference by dynamically importing module from known location + return import_module(f"dlt.load.{destination}") \ No newline at end of file diff --git a/dlt/common/exceptions.py b/dlt/common/exceptions.py index 440310c302..21e44450d0 100644 --- a/dlt/common/exceptions.py +++ b/dlt/common/exceptions.py @@ -76,8 +76,14 @@ def __init__(self, start_ts: float, end_ts: float) -> None: class DictValidationException(DltException): - def __init__(self, msg: str, path: str, field: str = None, value: Any = None): + def __init__(self, msg: str, path: str, field: str = None, value: Any = None) -> None: self.path = path self.field = field self.value = value super().__init__(msg) + + +class ArgumentsOverloadException(DltException): + def __init__(self, msg: str, *args: str) -> None: + self.args = args + super().__init__(msg) diff --git a/dlt/common/normalizers/json/relational.py b/dlt/common/normalizers/json/relational.py index 622457c20d..ad58de1659 100644 --- a/dlt/common/normalizers/json/relational.py +++ b/dlt/common/normalizers/json/relational.py @@ -6,7 +6,7 @@ from dlt.common.schema.utils import column_name_validator from dlt.common.utils import uniq_id, digest128 from dlt.common.normalizers.json import TNormalizedRowIterator, wrap_in_dict -from dlt.common.sources import TEventDLTMeta +from dlt.common.source import TEventDLTMeta from dlt.common.validation import validate_dict diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py new file mode 100644 index 0000000000..d9fd25d711 --- /dev/null +++ b/dlt/common/pipeline.py @@ -0,0 +1,45 @@ +from typing import Any, Callable, ClassVar, Protocol, Sequence + +from dlt.common.configuration.container import ContainerInjectableContext +from dlt.common.configuration import configspec +from dlt.common.destination import DestinationReference +from dlt.common.schema import Schema +from dlt.common.schema.typing import TColumnSchema, TWriteDisposition + + +class SupportsPipeline(Protocol): + """A protocol with core pipeline operations that lets high level abstractions ie. sources to access pipeline methods and properties""" + def run( + self, + source: Any = None, + destination: DestinationReference = None, + dataset_name: str = None, + table_name: str = None, + write_disposition: TWriteDisposition = None, + columns: Sequence[TColumnSchema] = None, + schema: Schema = None + ) -> Any: + ... 
+ + +@configspec(init=True) +class PipelineContext(ContainerInjectableContext): + _deferred_pipeline: Any + _pipeline: Any + + can_create_default: ClassVar[bool] = False + + def pipeline(self) -> SupportsPipeline: + if not self._pipeline: + # delayed pipeline creation + self._pipeline = self._deferred_pipeline() + return self._pipeline + + def activate(self, pipeline: SupportsPipeline) -> None: + self._pipeline = pipeline + + def is_activated(self) -> bool: + return self._pipeline is not None + + def __init__(self, deferred_pipeline: Callable[..., SupportsPipeline]) -> None: + self._deferred_pipeline = deferred_pipeline diff --git a/dlt/common/schema/__init__.py b/dlt/common/schema/__init__.py index ebbe4dfcaa..80f008f432 100644 --- a/dlt/common/schema/__init__.py +++ b/dlt/common/schema/__init__.py @@ -1,4 +1,4 @@ -from dlt.common.schema.typing import TSchemaUpdate, TStoredSchema, TTableSchemaColumns, TDataType, THintType, TColumnSchema, TColumnSchemaBase # noqa: F401 +from dlt.common.schema.typing import TSchemaUpdate, TStoredSchema, TTableSchemaColumns, TDataType, TColumnHint, TColumnSchema, TColumnSchemaBase # noqa: F401 from dlt.common.schema.typing import COLUMN_HINTS # noqa: F401 from dlt.common.schema.schema import Schema # noqa: F401 from dlt.common.schema.utils import add_missing_hints, verify_schema_hash # noqa: F401 diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index 9382f572cf..2a15ca1a74 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -9,7 +9,7 @@ from dlt.common.normalizers.json import TNormalizeJSONFunc from dlt.common.schema.typing import (TNormalizersConfig, TPartialTableSchema, TSchemaSettings, TSimpleRegex, TStoredSchema, TSchemaTables, TTableSchema, TTableSchemaColumns, TColumnSchema, TColumnProp, TDataType, - THintType, TWriteDisposition) + TColumnHint, TWriteDisposition) from dlt.common.schema import utils from dlt.common.schema.exceptions import (CannotCoerceColumnException, CannotCoerceNullException, InvalidSchemaName, ParentTableNotFoundException, SchemaCorruptedException) @@ -35,7 +35,7 @@ def __init__(self, name: str, normalizers: TNormalizersConfig = None) -> None: # list of preferred types: map regex on columns into types self._compiled_preferred_types: List[Tuple[REPattern, TDataType]] = [] # compiled default hints - self._compiled_hints: Dict[THintType, Sequence[REPattern]] = {} + self._compiled_hints: Dict[TColumnHint, Sequence[REPattern]] = {} # compiled exclude filters per table self._compiled_excludes: Dict[str, Sequence[REPattern]] = {} # compiled include filters per table @@ -209,7 +209,7 @@ def bump_version(self) -> Tuple[int, str]: self._stored_version, self._stored_version_hash = version return version - def filter_row_with_hint(self, table_name: str, hint_type: THintType, row: StrAny) -> StrAny: + def filter_row_with_hint(self, table_name: str, hint_type: TColumnHint, row: StrAny) -> StrAny: rv_row: DictStrAny = {} column_prop: TColumnProp = utils.hint_to_column_prop(hint_type) try: @@ -227,7 +227,7 @@ def filter_row_with_hint(self, table_name: str, hint_type: THintType, row: StrAn # dicts are ordered and we will return the rows with hints in the same order as they appear in the columns return rv_row - def merge_hints(self, new_hints: Mapping[THintType, Sequence[TSimpleRegex]]) -> None: + def merge_hints(self, new_hints: Mapping[TColumnHint, Sequence[TSimpleRegex]]) -> None: # validate regexes validate_dict(TSchemaSettings, {"default_hints": new_hints}, ".", 
validator_f=utils.simple_regex_validator) # prepare hints to be added @@ -407,7 +407,7 @@ def _infer_column_type(self, v: Any, col_name: str) -> TDataType: pass return mapped_type - def _infer_hint(self, hint_type: THintType, _: Any, col_name: str) -> bool: + def _infer_hint(self, hint_type: TColumnHint, _: Any, col_name: str) -> bool: if hint_type in self._compiled_hints: return any(h.search(col_name) for h in self._compiled_hints[hint_type]) else: @@ -425,6 +425,7 @@ def _add_standard_hints(self) -> None: def _configure_normalizers(self) -> None: if not self._normalizers_config: # create default normalizer config + # TODO: pass default normalizers as context or as config with defaults self._normalizers_config = utils.default_normalizers() # import desired modules naming_module = import_module(self._normalizers_config["names"]) diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index c587b864b6..f7078e6e7a 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -4,7 +4,7 @@ TDataType = Literal["text", "double", "bool", "timestamp", "bigint", "binary", "complex", "decimal", "wei"] -THintType = Literal["not_null", "partition", "cluster", "primary_key", "foreign_key", "sort", "unique"] +TColumnHint = Literal["not_null", "partition", "cluster", "primary_key", "foreign_key", "sort", "unique"] TColumnProp = Literal["name", "data_type", "nullable", "partition", "cluster", "primary_key", "foreign_key", "sort", "unique"] TWriteDisposition = Literal["skip", "append", "replace", "merge"] TTypeDetections = Literal["timestamp", "iso_timestamp", "large_integer", "hexbytes_to_text", "wei_to_double"] @@ -12,7 +12,7 @@ DATA_TYPES: Set[TDataType] = set(get_args(TDataType)) COLUMN_PROPS: Set[TColumnProp] = set(get_args(TColumnProp)) -COLUMN_HINTS: Set[THintType] = set(["partition", "cluster", "primary_key", "foreign_key", "sort", "unique"]) +COLUMN_HINTS: Set[TColumnHint] = set(["partition", "cluster", "primary_key", "foreign_key", "sort", "unique"]) WRITE_DISPOSITIONS: Set[TWriteDisposition] = set(get_args(TWriteDisposition)) @@ -74,7 +74,7 @@ class TNormalizersConfig(TypedDict, total=True): class TSchemaSettings(TypedDict, total=False): schema_sealed: Optional[bool] - default_hints: Optional[Dict[THintType, List[TSimpleRegex]]] + default_hints: Optional[Dict[TColumnHint, List[TSimpleRegex]]] preferred_types: Optional[Dict[TSimpleRegex, TDataType]] diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index c722134689..4a7d2f4f2f 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -19,7 +19,7 @@ from dlt.common.utils import str2bool from dlt.common.validation import TCustomValidator, validate_dict from dlt.common.schema import detections -from dlt.common.schema.typing import SIMPLE_REGEX_PREFIX, TColumnName, TNormalizersConfig, TPartialTableSchema, TSimpleRegex, TStoredSchema, TTableSchema, TTableSchemaColumns, TColumnSchemaBase, TColumnSchema, TColumnProp, TDataType, THintType, TTypeDetectionFunc, TTypeDetections, TWriteDisposition +from dlt.common.schema.typing import SIMPLE_REGEX_PREFIX, TColumnName, TNormalizersConfig, TPartialTableSchema, TSimpleRegex, TStoredSchema, TTableSchema, TTableSchemaColumns, TColumnSchemaBase, TColumnSchema, TColumnProp, TDataType, TColumnHint, TTypeDetectionFunc, TTypeDetections, TWriteDisposition from dlt.common.schema.exceptions import CannotCoerceColumnException, ParentTableNotFoundException, SchemaEngineNoUpgradePathException, SchemaException, TablePropertiesConflictException @@ -181,7 
+181,7 @@ def upgrade_engine_version(schema_dict: DictStrAny, from_engine: int, to_engine: } } # move settings, convert strings to simple regexes - d_h: Dict[THintType, List[TSimpleRegex]] = schema_dict.pop("hints", {}) + d_h: Dict[TColumnHint, List[TSimpleRegex]] = schema_dict.pop("hints", {}) for h_k, h_l in d_h.items(): d_h[h_k] = list(map(lambda r: TSimpleRegex("re:" + r), h_l)) p_t: Dict[TSimpleRegex, TDataType] = schema_dict.pop("preferred_types", {}) @@ -491,7 +491,7 @@ def compare_column(a: TColumnSchema, b: TColumnSchema) -> bool: return a["data_type"] == b["data_type"] and a["nullable"] == b["nullable"] -def hint_to_column_prop(h: THintType) -> TColumnProp: +def hint_to_column_prop(h: TColumnHint) -> TColumnProp: if h == "not_null": return "nullable" return h @@ -545,21 +545,40 @@ def load_table() -> TTableSchema: return table -def new_table(table_name: str, parent_name: str = None, write_disposition: TWriteDisposition = None, columns: Sequence[TColumnSchema] = None) -> TTableSchema: +def new_table( + table_name: str, + parent_table_name: str = None, + write_disposition: TWriteDisposition = None, + columns: Sequence[TColumnSchema] = None, + validate_schema: bool = False +) -> TTableSchema: + table: TTableSchema = { "name": table_name, "columns": {} if columns is None else {c["name"]: add_missing_hints(c) for c in columns} } - if parent_name: - table["parent"] = parent_name + if parent_table_name: + table["parent"] = parent_table_name assert write_disposition is None else: # set write disposition only for root tables table["write_disposition"] = write_disposition or DEFAULT_WRITE_DISPOSITION - # print(f"new table {table_name} cid {id(table['columns'])}") + if validate_schema: + validate_dict(TTableSchema, table, f"new_table/{table_name}") return table +def new_column(column_name: str, data_type: TDataType, nullable: bool = True, validate_schema: bool = False) -> TColumnSchema: + column = add_missing_hints({ + "name": column_name, + "data_type": data_type, + "nullable": nullable + }) + if validate_schema: + validate_dict(TColumnSchema, column, f"new_column/{column_name}") + return column + + def default_normalizers() -> TNormalizersConfig: return { "detections": ["timestamp", "iso_timestamp"], @@ -570,5 +589,5 @@ def default_normalizers() -> TNormalizersConfig: } -def standard_hints() -> Dict[THintType, List[TSimpleRegex]]: +def standard_hints() -> Dict[TColumnHint, List[TSimpleRegex]]: return None diff --git a/dlt/common/sources.py b/dlt/common/source.py similarity index 94% rename from dlt/common/sources.py rename to dlt/common/source.py index ebeb1f4a70..2bda51a5bb 100644 --- a/dlt/common/sources.py +++ b/dlt/common/source.py @@ -26,7 +26,9 @@ TAwaitableDataItem = Awaitable[TDirectDataItem] TResolvableDataItem = Union[TDirectDataItem, TDeferredDataItem, TAwaitableDataItem] -TFunDataItemDynHint = Callable[[TDataItem], Any] +TDynHintType = TypeVar("TDynHintType") +TFunHintTemplate = Callable[[TDataItem], TDynHintType] +TTableHintTemplate = Union[TDynHintType, TFunHintTemplate[TDynHintType]] # name of dlt metadata as part of the item DLT_METADATA_FIELD = "_dlt_meta" diff --git a/dlt/common/storages/data_item_storage.py b/dlt/common/storages/data_item_storage.py index 9c4e05249c..d3304d1418 100644 --- a/dlt/common/storages/data_item_storage.py +++ b/dlt/common/storages/data_item_storage.py @@ -3,7 +3,7 @@ from dlt.common import logger from dlt.common.schema import TTableSchemaColumns -from dlt.common.sources import TDirectDataItem +from dlt.common.source import TDirectDataItem from 
dlt.common.data_writers import TLoaderFileFormat, BufferedDataWriter diff --git a/dlt/common/storages/live_schema_storage.py b/dlt/common/storages/live_schema_storage.py index 0a3551b4eb..af7ce33af9 100644 --- a/dlt/common/storages/live_schema_storage.py +++ b/dlt/common/storages/live_schema_storage.py @@ -36,10 +36,18 @@ def load_schema(self, name: str) -> Schema: def save_schema(self, schema: Schema) -> str: rv = super().save_schema(schema) - # update the live schema with schema being saved but to not create live instance if not already present + # update the live schema with schema being saved but do not create live instance if not already present self._update_live_schema(schema, False) return rv + def initialize_import_if_new(self, schema: Schema) -> None: + if self.config.import_schema_path and schema.name not in self: + try: + self._load_import_schema(schema.name) + except FileNotFoundError: + # save import schema only if it not exist + self._export_schema(schema, self.config.import_schema_path) + def commit_live_schema(self, name: str) -> Schema: # if live schema exists and is modified then it must be used as an import schema live_schema = self.live_schemas.get(name) diff --git a/dlt/common/typing.py b/dlt/common/typing.py index 24bc16de9c..889220f72b 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -24,9 +24,9 @@ TAny = TypeVar("TAny", bound=Any) TAnyClass = TypeVar("TAnyClass", bound=object) TSecretValue = NewType("TSecretValue", str) # represent secret value ie. coming from Kubernetes/Docker secrets or other providers -TDataItem = Any # a single data item extracted from data source, normalized and loaded +TDataItem: TypeAlias = Any # a single data item extracted from data source, normalized and loaded -ConfigValue: None = None +ConfigValue: None = None # a value of type None indicating argument that may be injected by config provider TVariantBase = TypeVar("TVariantBase", covariant=True) TVariantRV = Tuple[str, Any] diff --git a/dlt/common/utils.py b/dlt/common/utils.py index dbeed70f14..18d65d45ac 100644 --- a/dlt/common/utils.py +++ b/dlt/common/utils.py @@ -10,7 +10,7 @@ from typing import Any, Dict, Iterator, Optional, Sequence, TypeVar, Mapping, List, TypedDict, Union -from dlt.common.typing import StrAny, DictStrAny, StrStr, TFun +from dlt.common.typing import AnyFun, StrAny, DictStrAny, StrStr, TFun T = TypeVar("T") @@ -182,3 +182,9 @@ def entry_point_file_stem() -> str: if len(sys.argv) > 0 and os.path.isfile(sys.argv[0]): return Path(sys.argv[0]).stem return None + + +def is_inner_function(f: AnyFun) -> bool: + """Checks if f is defined within other function""" + # inner functions have full nesting path in their qualname + return "" in f.__qualname__ diff --git a/dlt/dbt_runner/runner.py b/dlt/dbt_runner/runner.py index 07c6e1edc5..9d34d360ce 100644 --- a/dlt/dbt_runner/runner.py +++ b/dlt/dbt_runner/runner.py @@ -183,7 +183,7 @@ def configure(C: DBTRunnerConfiguration, collector: CollectorRegistry) -> None: model_elapsed_gauge, model_exec_info = create_gauges(REGISTRY) except ValueError as v: # ignore re-creation of gauges - if "Duplicated time-series" not in str(v): + if "Duplicated" not in str(v): raise diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index c048880948..c9a5476d50 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -2,7 +2,7 @@ from typing import ClassVar, List from dlt.common.utils import uniq_id -from dlt.common.sources import TDirectDataItem, TDataItem +from dlt.common.source import TDirectDataItem, 
TDataItem from dlt.common.schema import utils, TSchemaUpdate from dlt.common.storages import NormalizeStorage, DataItemStorage from dlt.common.configuration.specs import NormalizeVolumeConfiguration diff --git a/dlt/extract/pipe.py b/dlt/extract/pipe.py index bc22e77ec6..c70f1b684e 100644 --- a/dlt/extract/pipe.py +++ b/dlt/extract/pipe.py @@ -9,7 +9,7 @@ from dlt.common.configuration.inject import with_config from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec from dlt.common.typing import TDataItem -from dlt.common.sources import TDirectDataItem, TResolvableDataItem +from dlt.common.source import TDirectDataItem, TResolvableDataItem if TYPE_CHECKING: TItemFuture = Future[TDirectDataItem] @@ -99,10 +99,10 @@ def __init__(self, name: str, steps: List[TPipeStep] = None, parent: "Pipe" = No self.parent = parent @classmethod - def from_iterable(cls, name: str, gen: Union[Iterable[TResolvableDataItem], Iterator[TResolvableDataItem]]) -> "Pipe": + def from_iterable(cls, name: str, gen: Union[Iterable[TResolvableDataItem], Iterator[TResolvableDataItem]], parent: "Pipe" = None) -> "Pipe": if isinstance(gen, Iterable): gen = iter(gen) - return cls(name, [gen]) + return cls(name, [gen], parent=parent) @property def head(self) -> TPipeStep: @@ -215,6 +215,7 @@ def __init__(self, max_parallel_items: int, workers: int, futures_poll_interval: def from_pipe(cls, pipe: Pipe, *, max_parallel_items: int = 100, workers: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": if pipe.parent: pipe = pipe.full_pipe() + # TODO: if pipe head is callable then call it now # head must be iterator assert isinstance(pipe.head, Iterator) # create extractor @@ -245,7 +246,7 @@ def _fork_pipeline(pipe: Pipe) -> None: # add every head as source only once if not any(i.pipe == pipe for i in extract._sources): print("add to sources: " + pipe.name) - extract._sources.append(SourcePipeItem(pipe.head, 0, pipe)) + extract._sources.append(SourcePipeItem(pipe.head, 0, pipe)) for pipe in reversed(pipes): _fork_pipeline(pipe) diff --git a/dlt/extract/sources.py b/dlt/extract/sources.py index 8f8eef272e..ac25825215 100644 --- a/dlt/extract/sources.py +++ b/dlt/extract/sources.py @@ -1,47 +1,60 @@ import contextlib from copy import deepcopy import inspect -from typing import AsyncIterable, AsyncIterator, Coroutine, Dict, Generator, Iterable, Iterator, List, Set, TypedDict, Union, Awaitable, Callable, Sequence, TypeVar, cast, Optional, Any -from dlt.common.exceptions import DltException -from dlt.common.schema.utils import new_table +from typing import AsyncIterable, AsyncIterator, Coroutine, Dict, Generator, Iterable, Iterator, List, NamedTuple, Set, TypedDict, Union, Awaitable, Callable, Sequence, TypeVar, cast, Optional, Any +from dlt.common.exceptions import DltException from dlt.common.typing import TDataItem -from dlt.common.sources import TFunDataItemDynHint, TDirectDataItem -from dlt.common.schema.schema import Schema +from dlt.common.source import TFunHintTemplate, TDirectDataItem, TTableHintTemplate +from dlt.common.schema import Schema +from dlt.common.schema.utils import new_table from dlt.common.schema.typing import TPartialTableSchema, TTableSchema, TTableSchemaColumns, TWriteDisposition +from dlt.common.configuration.container import Container +from dlt.common.pipeline import PipelineContext from dlt.extract.pipe import FilterItem, Pipe, CreatePipeException, PipeIterator +# class HintArgs(NamedTuple): +# table_name: TTableHintTemplate[str] +# parent_table_name: 
TTableHintTemplate[str] = None +# write_disposition: TTableHintTemplate[TWriteDisposition] = None +# columns: TTableHintTemplate[TTableSchemaColumns] = None + + +# def apply_args(args: HintArgs): +# pass + +# apply_args() + class TTableSchemaTemplate(TypedDict, total=False): - name: Union[str, TFunDataItemDynHint] - description: Union[str, TFunDataItemDynHint] - write_disposition: Union[TWriteDisposition, TFunDataItemDynHint] + name: TTableHintTemplate[str] + description: TTableHintTemplate[str] + write_disposition: TTableHintTemplate[TWriteDisposition] # table_sealed: Optional[bool] - parent: Union[str, TFunDataItemDynHint] - columns: Union[TTableSchemaColumns, TFunDataItemDynHint] + parent: TTableHintTemplate[str] + columns: TTableHintTemplate[TTableSchemaColumns] class DltResourceSchema: def __init__(self, name: str, table_schema_template: TTableSchemaTemplate = None): # self.__name__ = name self.name = name - self._table_name_hint_fun: TFunDataItemDynHint = None + self._table_name_hint_fun: TFunHintTemplate[str] = None self._table_has_other_dynamic_hints: bool = False self._table_schema_template: TTableSchemaTemplate = None self._table_schema: TPartialTableSchema = None if table_schema_template: - self._set_template(table_schema_template) + self.set_template(table_schema_template) def table_schema(self, item: TDataItem = None) -> TPartialTableSchema: - if not self._table_schema_template: # if table template is not present, generate partial table from name if not self._table_schema: self._table_schema = new_table(self.name) return self._table_schema - def _resolve_hint(hint: Union[Any, TFunDataItemDynHint]) -> Any: + def _resolve_hint(hint: TTableHintTemplate[Any]) -> Any: if callable(hint): return hint(item) else: @@ -52,50 +65,101 @@ def _resolve_hint(hint: Union[Any, TFunDataItemDynHint]) -> Any: if item is None: raise DataItemRequiredForDynamicTableHints(self.name) else: - cloned_template = deepcopy(self._table_schema_template) - return cast(TPartialTableSchema, {k: _resolve_hint(v) for k, v in cloned_template.items()}) + # cloned_template = deepcopy(self._table_schema_template) + return cast(TPartialTableSchema, {k: _resolve_hint(v) for k, v in self._table_schema_template.items()}) else: return cast(TPartialTableSchema, self._table_schema_template) - def _set_template(self, table_schema_template: TTableSchemaTemplate) -> None: - # validate template - # TODO: name must be set if any other properties are set - # TODO: remove all none values - + def apply_hints( + self, + table_name: TTableHintTemplate[str] = None, + parent_table_name: TTableHintTemplate[str] = None, + write_disposition: TTableHintTemplate[TWriteDisposition] = None, + columns: TTableHintTemplate[TTableSchemaColumns] = None, + ) -> None: + t = None + if not self._table_schema_template: + # if there's no template yet, create and set new one + t = self.new_table_template(table_name, parent_table_name, write_disposition, columns) + else: + # set single hints + t = deepcopy(self._table_schema_template) + if table_name: + t["name"] = table_name + if parent_table_name: + t["parent"] = parent_table_name + if write_disposition: + t["write_disposition"] = write_disposition + if columns: + t["columns"] = columns + self.set_template(t) + + def set_template(self, table_schema_template: TTableSchemaTemplate) -> None: # if "name" is callable in the template then the table schema requires actual data item to be inferred - name_hint = table_schema_template.get("name") + name_hint = table_schema_template["name"] if callable(name_hint): 
self._table_name_hint_fun = name_hint + else: + self._table_name_hint_fun = None # check if any other hints in the table template should be inferred from data self._table_has_other_dynamic_hints = any(callable(v) for k, v in table_schema_template.items() if k != "name") - - if self._table_has_other_dynamic_hints and not self._table_name_hint_fun: - raise InvalidTableSchemaTemplate("Table name must be a function if any other table hint is a function") self._table_schema_template = table_schema_template + @staticmethod + def new_table_template( + table_name: TTableHintTemplate[str], + parent_table_name: TTableHintTemplate[str] = None, + write_disposition: TTableHintTemplate[TWriteDisposition] = None, + columns: TTableHintTemplate[TTableSchemaColumns] = None, + ) -> TTableSchemaTemplate: + if not table_name: + raise InvalidTableSchemaTemplate("Table template name must be a string or function taking TDataItem") + # create a table schema template where hints can be functions taking TDataItem + new_template: TTableSchemaTemplate = new_table(table_name, parent_table_name, write_disposition=write_disposition, columns=columns) # type: ignore + # if any of the hints is a function then name must be as well + if any(callable(v) for k, v in new_template.items() if k != "name") and not callable(table_name): + raise InvalidTableSchemaTemplate("Table name must be a function if any other table hint is a function") + return new_template class DltResource(Iterable[TDirectDataItem], DltResourceSchema): - def __init__(self, pipe: Pipe, table_schema_template: TTableSchemaTemplate): + def __init__(self, pipe: Pipe, table_schema_template: TTableSchemaTemplate, selected: bool): self.name = pipe.name + self.selected = selected self._pipe = pipe super().__init__(self.name, table_schema_template) @classmethod - def from_data(cls, data: Any, name: str = None, table_schema_template: TTableSchemaTemplate = None) -> "DltResource": + def from_data(cls, data: Any, name: str = None, table_schema_template: TTableSchemaTemplate = None, selected: bool = True, depends_on: "DltResource" = None) -> "DltResource": # call functions assuming that they do not take any parameters, typically they are generator functions if callable(data): + # use inspect.isgeneratorfunction to see if this is generator or not + # if it is then call it, if not then keep the callable assuming that it will return iterable/iterator + # if inspect.isgeneratorfunction(data): + # data = data() + # else: data = data() if isinstance(data, DltResource): return data if isinstance(data, Pipe): - return cls(data, table_schema_template) + return cls(data, table_schema_template, selected) # several iterable types are not allowed and must be excluded right away if isinstance(data, (AsyncIterator, AsyncIterable, str, dict)): raise InvalidResourceDataType("Invalid data type for DltResource", type(data)) + # check if depends_on is a valid resource + parent_pipe: Pipe = None + if depends_on: + if not isinstance(depends_on, DltResource): + # if this is generator function provide nicer exception + if inspect.isgeneratorfunction(inspect.unwrap(depends_on)): + raise ParentResourceIsGeneratorFunction() + else: + raise ParentNotAResource() + parent_pipe = depends_on._pipe + # create resource from iterator or iterable if isinstance(data, (Iterable, Iterator)): if inspect.isgenerator(data): @@ -103,9 +167,9 @@ def from_data(cls, data: Any, name: str = None, table_schema_template: TTableSch else: name = name or None if not name: - raise ResourceNameRequired("The DltResource name 
was not provide or could not be inferred.") - pipe = Pipe.from_iterable(name, data) - return cls(pipe, table_schema_template) + raise ResourceNameRequired("The DltResource name was not provided or could not be inferred.") + pipe = Pipe.from_iterable(name, data, parent=parent_pipe) + return cls(pipe, table_schema_template, selected) # some other data type that is not supported raise InvalidResourceDataType("Invalid data type for DltResource", type(data)) @@ -143,7 +207,7 @@ def __init__(self, schema: Schema, resources: Sequence[DltResource] = None) -> N self.name = schema.name self._schema = schema self._resources: List[DltResource] = list(resources or []) - self._enabled_resource_names: Set[str] = set(r.name for r in self._resources) + self._enabled_resource_names: Set[str] = set(r.name for r in self._resources if r.selected) @classmethod def from_data(cls, schema: Schema, data: Any) -> "DltSource": @@ -185,6 +249,10 @@ def pipes(self) -> Sequence[Pipe]: def schema(self) -> Schema: return self._schema + @schema.setter + def schema(self, value: Schema) -> None: + self._schema = value + def discover_schema(self) -> Schema: # extract tables from all resources and update internal schema for r in self._resources: @@ -202,6 +270,9 @@ def select(self, *resource_names: str) -> "DltSource": return self + def run(self, destination: Any) -> Any: + return Container()[PipelineContext].pipeline().run(source=self, destination=destination) + def __iter__(self) -> Iterator[TDirectDataItem]: return map(lambda item: item.item, PipeIterator.from_pipes(self.pipes)) diff --git a/dlt/load/bigquery/__init__.py b/dlt/load/bigquery/__init__.py index e004fe7afd..b14aefc661 100644 --- a/dlt/load/bigquery/__init__.py +++ b/dlt/load/bigquery/__init__.py @@ -3,10 +3,8 @@ from dlt.common.schema.schema import Schema from dlt.common.typing import ConfigValue from dlt.common.configuration import with_config -from dlt.common.configuration.specs import DestinationCapabilitiesContext +from dlt.common.destination import DestinationCapabilitiesContext, JobClientBase, DestinationClientConfiguration -from dlt.load.client_base import JobClientBase -from dlt.load.configuration import DestinationClientConfiguration from dlt.load.bigquery.configuration import BigQueryClientConfiguration diff --git a/dlt/load/bigquery/bigquery.py b/dlt/load/bigquery/bigquery.py index 60996a0591..0b228192e4 100644 --- a/dlt/load/bigquery/bigquery.py +++ b/dlt/load/bigquery/bigquery.py @@ -13,12 +13,13 @@ from dlt.common.typing import StrAny from dlt.common.schema.typing import TTableSchema, TWriteDisposition from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.configuration.specs import GcpClientCredentials, DestinationCapabilitiesContext +from dlt.common.configuration.specs import GcpClientCredentials +from dlt.common.destination import DestinationCapabilitiesContext, TLoadJobStatus, LoadJob from dlt.common.data_writers import escape_bigquery_identifier from dlt.common.schema import TColumnSchema, TDataType, Schema, TTableSchemaColumns -from dlt.load.typing import TLoadJobStatus, DBCursor -from dlt.load.client_base import SqlClientBase, LoadJob +from dlt.load.typing import DBCursor +from dlt.load.client_base import SqlClientBase from dlt.load.client_base_impl import SqlJobClientBase from dlt.load.exceptions import LoadClientSchemaWillNotUpdate, LoadJobNotExistsException, LoadJobServerTerminalException, LoadUnknownTableException diff --git a/dlt/load/bigquery/configuration.py 
b/dlt/load/bigquery/configuration.py index ab66d48e45..325ff6430f 100644 --- a/dlt/load/bigquery/configuration.py +++ b/dlt/load/bigquery/configuration.py @@ -5,8 +5,7 @@ from dlt.common.configuration import configspec from dlt.common.configuration.specs import GcpClientCredentials from dlt.common.configuration.exceptions import ConfigEntryMissingException - -from dlt.load.configuration import DestinationClientDwhConfiguration +from dlt.common.destination import DestinationClientDwhConfiguration @configspec(init=True) diff --git a/dlt/load/client_base.py b/dlt/load/client_base.py index b694128432..9ede7e6483 100644 --- a/dlt/load/client_base.py +++ b/dlt/load/client_base.py @@ -2,100 +2,8 @@ from contextlib import contextmanager from types import TracebackType from typing import Any, ContextManager, Generic, Iterator, Optional, Sequence, Tuple, Type, AnyStr, Protocol -from pathlib import Path -from dlt.common.schema import Schema -from dlt.common.schema.typing import TTableSchema -from dlt.common.typing import ConfigValue -from dlt.common.configuration.specs import DestinationCapabilitiesContext - -from dlt.load.configuration import DestinationClientConfiguration -from dlt.load.typing import TLoadJobStatus, TNativeConn, DBCursor - - -class LoadJob: - """Represents a job that loads a single file - - Each job starts in "running" state and ends in one of terminal states: "retry", "failed" or "completed". - Each job is uniquely identified by a file name. The file is guaranteed to exist in "running" state. In terminal state, the file may not be present. - In "running" state, the loader component periodically gets the state via `status()` method. When terminal state is reached, load job is discarded and not called again. - `exception` method is called to get error information in "failed" and "retry" states. - - The `__init__` method is responsible to put the Job in "running" state. It may raise `LoadClientTerminalException` and `LoadClientTransientException` tp - immediately transition job into "failed" or "retry" state respectively. 
- """ - def __init__(self, file_name: str) -> None: - """ - File name is also a job id (or job id is deterministically derived) so it must be globally unique - """ - self._file_name = file_name - - @abstractmethod - def status(self) -> TLoadJobStatus: - pass - - @abstractmethod - def file_name(self) -> str: - pass - - @abstractmethod - def exception(self) -> str: - pass - - -class JobClientBase(ABC): - def __init__(self, schema: Schema, config: DestinationClientConfiguration) -> None: - self.schema = schema - self.config = config - - @abstractmethod - def initialize_storage(self) -> None: - pass - - @abstractmethod - def update_storage_schema(self) -> None: - pass - - @abstractmethod - def start_file_load(self, table: TTableSchema, file_path: str) -> LoadJob: - pass - - @abstractmethod - def restore_file_load(self, file_path: str) -> LoadJob: - pass - - @abstractmethod - def complete_load(self, load_id: str) -> None: - pass - - @abstractmethod - def __enter__(self) -> "JobClientBase": - pass - - @abstractmethod - def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None: - pass - - @classmethod - @abstractmethod - def capabilities(cls) -> DestinationCapabilitiesContext: - pass - - # @classmethod - # @abstractmethod - # def configure(cls, initial_values: StrAny = None) -> Tuple[BaseConfiguration, CredentialsConfiguration]: - # pass - - -class DestinationReference(Protocol): - def capabilities(self) -> DestinationCapabilitiesContext: - ... - - def client(self, schema: Schema, initial_config: DestinationClientConfiguration = ConfigValue) -> "JobClientBase": - ... - - def spec(self) -> Type[DestinationClientConfiguration]: - ... +from dlt.load.typing import TNativeConn, DBCursor class SqlClientBase(ABC, Generic[TNativeConn]): diff --git a/dlt/load/client_base_impl.py b/dlt/load/client_base_impl.py index 894e8fb605..c681a61992 100644 --- a/dlt/load/client_base_impl.py +++ b/dlt/load/client_base_impl.py @@ -5,10 +5,10 @@ from dlt.common import pendulum, logger from dlt.common.storages import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns +from dlt.common.destination import DestinationClientConfiguration, TLoadJobStatus, LoadJob, JobClientBase -from dlt.load.typing import TLoadJobStatus, TNativeConn -from dlt.load.client_base import LoadJob, JobClientBase, SqlClientBase -from dlt.load.configuration import DestinationClientConfiguration +from dlt.load.typing import TNativeConn +from dlt.load.client_base import SqlClientBase from dlt.load.exceptions import LoadClientSchemaVersionCorrupted diff --git a/dlt/load/configuration.py b/dlt/load/configuration.py index bdf79eb604..22a985a00d 100644 --- a/dlt/load/configuration.py +++ b/dlt/load/configuration.py @@ -5,20 +5,9 @@ from dlt.common.configuration.specs.load_volume_configuration import LoadVolumeConfiguration -@configspec(init=True) -class DestinationClientConfiguration(BaseConfiguration): - destination_name: str = None # which destination to load data to - credentials: Optional[CredentialsConfiguration] - - -@configspec(init=True) -class DestinationClientDwhConfiguration(DestinationClientConfiguration): - dataset_name: str = None # dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix - default_schema_name: Optional[str] = None # name of default schema to be used to name effective dataset to load data to - - @configspec(init=True) class LoaderConfiguration(PoolRunnerConfiguration): workers: int = 20 
# how many parallel loads can be executed pool_type: TPoolType = "thread" # mostly i/o (upload) so may be thread pool + always_wipe_storage: bool = False # removes all data in the storage load_storage_config: LoadVolumeConfiguration = None diff --git a/dlt/load/dummy/__init__.py b/dlt/load/dummy/__init__.py index 2ac57d272a..d4906310b1 100644 --- a/dlt/load/dummy/__init__.py +++ b/dlt/load/dummy/__init__.py @@ -3,10 +3,8 @@ from dlt.common.schema.schema import Schema from dlt.common.typing import ConfigValue from dlt.common.configuration import with_config -from dlt.common.configuration.specs import DestinationCapabilitiesContext +from dlt.common.destination import DestinationCapabilitiesContext, JobClientBase, DestinationClientConfiguration -from dlt.load.client_base import JobClientBase -from dlt.load.configuration import DestinationClientConfiguration from dlt.load.dummy.configuration import DummyClientConfiguration diff --git a/dlt/load/dummy/configuration.py b/dlt/load/dummy/configuration.py index dc301ecb73..f39180e1e4 100644 --- a/dlt/load/dummy/configuration.py +++ b/dlt/load/dummy/configuration.py @@ -1,7 +1,5 @@ from dlt.common.configuration import configspec -from dlt.common.data_writers import TLoaderFileFormat - -from dlt.load.configuration import DestinationClientConfiguration +from dlt.common.destination import DestinationClientConfiguration, TLoaderFileFormat @configspec(init=True) diff --git a/dlt/load/dummy/dummy.py b/dlt/load/dummy/dummy.py index e3ca263cbb..1976b89b65 100644 --- a/dlt/load/dummy/dummy.py +++ b/dlt/load/dummy/dummy.py @@ -6,10 +6,8 @@ from dlt.common.schema import Schema from dlt.common.storages import FileStorage from dlt.common.schema.typing import TTableSchema -from dlt.common.configuration.specs import DestinationCapabilitiesContext +from dlt.common.destination import DestinationCapabilitiesContext, JobClientBase, LoadJob, TLoadJobStatus -from dlt.load.client_base import JobClientBase, LoadJob -from dlt.load.typing import TLoadJobStatus from dlt.load.exceptions import (LoadJobNotExistsException, LoadJobInvalidStateTransitionException, LoadClientTerminalException, LoadClientTransientException) diff --git a/dlt/load/exceptions.py b/dlt/load/exceptions.py index 7944943b24..60f4b8008d 100644 --- a/dlt/load/exceptions.py +++ b/dlt/load/exceptions.py @@ -1,7 +1,6 @@ from typing import Sequence from dlt.common.exceptions import DltException, TerminalException, TransientException - -from dlt.load.typing import TLoadJobStatus +from dlt.common.destination import TLoadJobStatus class LoadException(DltException): diff --git a/dlt/load/load.py b/dlt/load/load.py index 3980a3203c..a867915bd8 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -12,11 +12,10 @@ from dlt.common.schema.typing import TTableSchema from dlt.common.storages import LoadStorage from dlt.common.telemetry import get_logging_extras, set_gauge_all_labels +from dlt.common.destination import JobClientBase, DestinationReference, LoadJob, TLoadJobStatus, DestinationClientConfiguration -from dlt.load.client_base import JobClientBase, DestinationReference, LoadJob from dlt.load.client_base_impl import LoadEmptyJob -from dlt.load.typing import TLoadJobStatus -from dlt.load.configuration import LoaderConfiguration, DestinationClientConfiguration +from dlt.load.configuration import LoaderConfiguration from dlt.load.exceptions import LoadClientTerminalException, LoadClientTransientException, LoadClientUnsupportedWriteDisposition, LoadClientUnsupportedFileFormats, LoadJobNotExistsException, 
LoadUnknownTableException @@ -46,7 +45,7 @@ def __init__( Load.create_gauges(collector) except ValueError as v: # ignore re-creation of gauges - if "Duplicated timeseries" not in str(v): + if "Duplicated" not in str(v): raise def create_storage(self, is_storage_owner: bool) -> LoadStorage: diff --git a/dlt/load/redshift/__init__.py b/dlt/load/redshift/__init__.py index 0123092700..45723d5540 100644 --- a/dlt/load/redshift/__init__.py +++ b/dlt/load/redshift/__init__.py @@ -3,10 +3,8 @@ from dlt.common.schema.schema import Schema from dlt.common.typing import ConfigValue from dlt.common.configuration import with_config -from dlt.common.configuration.specs import DestinationCapabilitiesContext +from dlt.common.destination import DestinationCapabilitiesContext, JobClientBase, DestinationClientConfiguration -from dlt.load.client_base import JobClientBase -from dlt.load.configuration import DestinationClientConfiguration from dlt.load.redshift.configuration import RedshiftClientConfiguration diff --git a/dlt/load/redshift/configuration.py b/dlt/load/redshift/configuration.py index fe75be5457..ce724eec4a 100644 --- a/dlt/load/redshift/configuration.py +++ b/dlt/load/redshift/configuration.py @@ -1,7 +1,6 @@ from dlt.common.configuration import configspec from dlt.common.configuration.specs import PostgresCredentials - -from dlt.load.configuration import DestinationClientDwhConfiguration +from dlt.common.destination import DestinationClientDwhConfiguration @configspec(init=True) diff --git a/dlt/load/redshift/redshift.py b/dlt/load/redshift/redshift.py index 136d06c1a5..be43841218 100644 --- a/dlt/load/redshift/redshift.py +++ b/dlt/load/redshift/redshift.py @@ -11,15 +11,16 @@ from typing import Any, AnyStr, Dict, Iterator, List, Optional, Sequence, Tuple from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.configuration.specs import PostgresCredentials, DestinationCapabilitiesContext +from dlt.common.configuration.specs import PostgresCredentials +from dlt.common.destination import DestinationCapabilitiesContext, LoadJob, TLoadJobStatus from dlt.common.data_writers import escape_redshift_identifier -from dlt.common.schema import COLUMN_HINTS, TColumnSchema, TColumnSchemaBase, TDataType, THintType, Schema, TTableSchemaColumns, add_missing_hints +from dlt.common.schema import COLUMN_HINTS, TColumnSchema, TColumnSchemaBase, TDataType, TColumnHint, Schema, TTableSchemaColumns, add_missing_hints from dlt.common.schema.typing import TTableSchema, TWriteDisposition from dlt.common.storages.file_storage import FileStorage from dlt.load.exceptions import LoadClientSchemaWillNotUpdate, LoadClientTerminalInnerException, LoadClientTransientInnerException -from dlt.load.typing import TLoadJobStatus, DBCursor -from dlt.load.client_base import SqlClientBase, LoadJob +from dlt.load.typing import DBCursor +from dlt.load.client_base import SqlClientBase from dlt.load.client_base_impl import SqlJobClientBase, LoadEmptyJob from dlt.load.redshift import capabilities @@ -47,7 +48,7 @@ "numeric": "decimal" } -HINT_TO_REDSHIFT_ATTR: Dict[THintType, str] = { +HINT_TO_REDSHIFT_ATTR: Dict[TColumnHint, str] = { "cluster": "DISTKEY", # it is better to not enforce constraints in redshift # "primary_key": "PRIMARY KEY", diff --git a/dlt/load/typing.py b/dlt/load/typing.py index d6ec5f457a..f576ae5c97 100644 --- a/dlt/load/typing.py +++ b/dlt/load/typing.py @@ -1,6 +1,5 @@ from typing import Any, AnyStr, List, Literal, Optional, Tuple, TypeVar -TLoadJobStatus = 
Literal["running", "failed", "retry", "completed"] # native connection TNativeConn = TypeVar("TNativeConn", bound="object") diff --git a/dlt/normalize/configuration.py b/dlt/normalize/configuration.py index 1dd4b3b659..df4540aa0b 100644 --- a/dlt/normalize/configuration.py +++ b/dlt/normalize/configuration.py @@ -1,5 +1,6 @@ from dlt.common.configuration import configspec -from dlt.common.configuration.specs import LoadVolumeConfiguration, NormalizeVolumeConfiguration, SchemaVolumeConfiguration, PoolRunnerConfiguration, DestinationCapabilitiesContext, TPoolType +from dlt.common.configuration.specs import LoadVolumeConfiguration, NormalizeVolumeConfiguration, SchemaVolumeConfiguration, PoolRunnerConfiguration, TPoolType +from dlt.common.destination import DestinationCapabilitiesContext @configspec(init=True) diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 3ff8f8d6ef..4954343f02 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -52,7 +52,7 @@ def __init__(self, collector: CollectorRegistry = REGISTRY, schema_storage: Sche self.create_gauges(collector) except ValueError as v: # ignore re-creation of gauges - if "Duplicated time-series" not in str(v): + if "Duplicated" not in str(v): raise @staticmethod diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index e055717312..27cc83acef 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -17,7 +17,7 @@ from dlt.common.schema import Schema from dlt.common.typing import DictStrAny, StrAny from dlt.common.utils import uniq_id, is_interactive -from dlt.common.sources import DLT_METADATA_FIELD, TItem, with_table_name +from dlt.common.source import DLT_METADATA_FIELD, TItem, with_table_name from dlt.extract.extractor_storage import ExtractorStorageBase from dlt.load.client_base import SqlClientBase, SqlJobClientBase diff --git a/examples/sources/rasa_tracker_store.py b/examples/sources/rasa_tracker_store.py index eae691fb01..45f6c468d8 100644 --- a/examples/sources/rasa_tracker_store.py +++ b/examples/sources/rasa_tracker_store.py @@ -1,5 +1,5 @@ from typing import Iterator -from dlt.common.sources import with_table_name +from dlt.common.source import with_table_name from dlt.common.typing import DictStrAny from dlt.common.time import timestamp_within diff --git a/examples/sources/singer_tap.py b/examples/sources/singer_tap.py index d23402e807..c04b94bb00 100644 --- a/examples/sources/singer_tap.py +++ b/examples/sources/singer_tap.py @@ -4,7 +4,7 @@ from dlt.common import json from dlt.common.runners.venv import Venv -from dlt.common.sources import with_table_name +from dlt.common.source import with_table_name from dlt.common.typing import DictStrAny, StrAny, StrOrBytesPath from examples.sources.stdout import get_source as get_singer_pipe diff --git a/tests/common/configuration/test_inject.py b/tests/common/configuration/test_inject.py index d889263e2a..fa4f72a4a1 100644 --- a/tests/common/configuration/test_inject.py +++ b/tests/common/configuration/test_inject.py @@ -3,7 +3,7 @@ from dlt.common import Decimal from dlt.common.typing import TSecretValue -from dlt.common.configuration.inject import _spec_from_signature, _get_spec_name_from_f, with_config +from dlt.common.configuration.inject import _spec_from_signature, _get_spec_name_from_f, get_fun_spec, with_config from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration from tests.utils import preserve_environ @@ -127,6 +127,10 @@ def f(pipeline_name, value): f("pipe") + # make sure the spec 
is available for decorated fun + assert get_fun_spec(f) is not None + assert hasattr(get_fun_spec(f), "pipeline_name") + def test_inject_with_spec() -> None: pass diff --git a/tests/common/schema/test_inference.py b/tests/common/schema/test_inference.py index efb2650ce8..635e6a47d2 100644 --- a/tests/common/schema/test_inference.py +++ b/tests/common/schema/test_inference.py @@ -338,7 +338,7 @@ def test_infer_with_autodetection(schema: Schema) -> None: def test_update_schema_parent_missing(schema: Schema) -> None: - tab1 = utils.new_table("tab1", parent_name="tab_parent") + tab1 = utils.new_table("tab1", parent_table_name="tab_parent") # tab_parent is missing in schema with pytest.raises(ParentTableNotFoundException) as exc_val: schema.update_schema(tab1) diff --git a/tests/common/storages/test_schema_storage.py b/tests/common/storages/test_schema_storage.py index 10d4ccaf6f..2beed89c92 100644 --- a/tests/common/storages/test_schema_storage.py +++ b/tests/common/storages/test_schema_storage.py @@ -134,6 +134,10 @@ def test_list_schemas(storage: SchemaStorage) -> None: assert set(storage.list_schemas()) == set(["ethereum", "event"]) storage.remove_schema("event") assert storage.list_schemas() == ["ethereum"] + # add schema with _ in the name + schema = Schema("dlt_pipeline") + storage.save_schema(schema) + assert set(storage.list_schemas()) == set(["ethereum", "dlt_pipeline"]) def test_remove_schema(storage: SchemaStorage) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index e6a875c644..a7f696b7d2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,8 @@ def pytest_configure(config): run_configuration.RunConfiguration.config_files_storage_path = os.path.join(test_storage_root, "config/%s") load_volume_configuration.LoadVolumeConfiguration.load_volume_path = os.path.join(test_storage_root, "load") + delattr(load_volume_configuration.LoadVolumeConfiguration, "__init__") + load_volume_configuration.LoadVolumeConfiguration = dataclasses.dataclass(load_volume_configuration.LoadVolumeConfiguration, init=True, repr=False) normalize_volume_configuration.NormalizeVolumeConfiguration.normalize_volume_path = os.path.join(test_storage_root, "normalize") # delete __init__, otherwise it will not be recreated by dataclass diff --git a/tests/load/test_client.py b/tests/load/test_client.py index 89dba8dd45..699d0e2c6f 100644 --- a/tests/load/test_client.py +++ b/tests/load/test_client.py @@ -290,7 +290,7 @@ def test_write_dispositions(client: SqlJobClientBase, write_disposition: str, fi child_table = client.schema.normalize_make_path(table_name, "child") # add child table without write disposition so it will be inferred from the parent client.schema.update_schema( - new_table(child_table, columns=TABLE_UPDATE, parent_name=table_name) + new_table(child_table, columns=TABLE_UPDATE, parent_table_name=table_name) ) client.schema.bump_version() client.update_storage_schema() diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index db8dab7230..16194212c3 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -10,11 +10,10 @@ from dlt.common.schema import Schema from dlt.common.storages import FileStorage, LoadStorage from dlt.common.storages.load_storage import JobWithUnsupportedWriterException -from dlt.common.typing import StrAny from dlt.common.utils import uniq_id +from dlt.common.destination import DestinationReference, LoadJob from dlt.load import Load -from dlt.load.client_base import DestinationReference, LoadJob from 
dlt.load.client_base_impl import LoadEmptyJob from dlt.load import dummy diff --git a/tests/load/utils.py b/tests/load/utils.py index 8cd0ff8ede..80d39def16 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -1,11 +1,12 @@ import contextlib from importlib import import_module import os -from typing import Any, ContextManager, Iterable, Iterator, List, Sequence, cast, IO +from typing import Any, ContextManager, Iterator, List, Sequence, cast, IO from dlt.common import json, Decimal from dlt.common.configuration import resolve_configuration -from dlt.common.configuration.specs.schema_volume_configuration import SchemaVolumeConfiguration +from dlt.common.configuration.specs import SchemaVolumeConfiguration +from dlt.common.destination import DestinationClientDwhConfiguration, DestinationReference, JobClientBase, LoadJob from dlt.common.data_writers import DataWriter from dlt.common.schema import TColumnSchema, TTableSchemaColumns from dlt.common.storages import SchemaStorage, FileStorage @@ -15,9 +16,8 @@ from dlt.common.utils import uniq_id from dlt.load import Load -from dlt.load.client_base import DestinationReference, JobClientBase, LoadJob from dlt.load.client_base_impl import SqlJobClientBase -from dlt.load.configuration import DestinationClientDwhConfiguration + TABLE_UPDATE: List[TColumnSchema] = [ { diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py index f9c2abbe6e..a4e3280ce7 100644 --- a/tests/normalize/test_normalize.py +++ b/tests/normalize/test_normalize.py @@ -6,12 +6,12 @@ from multiprocessing.dummy import Pool as ThreadPool from dlt.common import json -from dlt.common.configuration.specs.destination_capabilities_context import TLoaderFileFormat +from dlt.common.destination import TLoaderFileFormat from dlt.common.utils import uniq_id from dlt.common.typing import StrAny from dlt.common.schema import TDataType from dlt.common.storages import NormalizeStorage, LoadStorage -from dlt.common.configuration.specs import DestinationCapabilitiesContext +from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.configuration.container import Container from dlt.extract.extract import ExtractorStorage From ce9c3b9d51a62e1b791f2819008a5f4e6a7dcbdf Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 26 Oct 2022 21:43:29 +0200 Subject: [PATCH 49/66] adds ignored example secrets.toml --- .../pipeline/examples/project_structure/.dlt/secrets.toml | 6 ++++++ tests/common/cases/configuration/.dlt/secrets.toml | 0 2 files changed, 6 insertions(+) create mode 100644 experiments/pipeline/examples/project_structure/.dlt/secrets.toml create mode 100644 tests/common/cases/configuration/.dlt/secrets.toml diff --git a/experiments/pipeline/examples/project_structure/.dlt/secrets.toml b/experiments/pipeline/examples/project_structure/.dlt/secrets.toml new file mode 100644 index 0000000000..6844d188d5 --- /dev/null +++ b/experiments/pipeline/examples/project_structure/.dlt/secrets.toml @@ -0,0 +1,6 @@ +api_key="set me up" + +[destination.bigquery.credentials] +project_id="set me up" +private_key="set me up" +client_email="set me up" diff --git a/tests/common/cases/configuration/.dlt/secrets.toml b/tests/common/cases/configuration/.dlt/secrets.toml new file mode 100644 index 0000000000..e69de29bb2 From c4f6c8cf2143c0d1a6c0d7f861c36ed4f10a47c1 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Thu, 27 Oct 2022 22:15:15 +0200 Subject: [PATCH 50/66] implements toml config provider, changes how embedded and config namespaces work, config 
initial values behaving like any other values + tests --- .../configuration/providers/container.py | 4 +- .../configuration/providers/provider.py | 7 - dlt/common/configuration/providers/toml.py | 91 ++++++-- dlt/common/configuration/resolve.py | 199 ++++++++++-------- .../configuration/specs/base_configuration.py | 8 +- .../specs/config_providers_context.py | 13 +- .../specs/gcp_client_credentials.py | 2 - .../specs/postgres_credentials.py | 2 - dlt/common/destination.py | 4 +- dlt/common/pipeline.py | 16 ++ experiments/pipeline/__init__.py | 5 +- tests/.example.env | 14 +- .../cases/configuration/.dlt/config.toml | 27 +++ .../cases/configuration/.dlt/secrets.toml | 70 ++++++ .../configuration/test_configuration.py | 120 +++++------ tests/common/configuration/test_container.py | 7 +- tests/common/configuration/test_namespaces.py | 42 ++-- .../configuration/test_toml_provider.py | 169 +++++++++++++++ tests/common/configuration/utils.py | 61 +++++- tests/dbt_runner/test_runner_redshift.py | 2 +- .../bigquery/test_bigquery_table_builder.py | 4 +- .../redshift/test_redshift_table_builder.py | 2 +- tests/utils.py | 6 +- 23 files changed, 646 insertions(+), 229 deletions(-) create mode 100644 tests/common/cases/configuration/.dlt/config.toml create mode 100644 tests/common/configuration/test_toml_provider.py diff --git a/dlt/common/configuration/providers/container.py b/dlt/common/configuration/providers/container.py index 30699a40e5..cd1f1a7049 100644 --- a/dlt/common/configuration/providers/container.py +++ b/dlt/common/configuration/providers/container.py @@ -1,5 +1,5 @@ import contextlib -from typing import Any, Optional, Type, Tuple +from typing import Any, ClassVar, Optional, Type, Tuple from dlt.common.configuration.container import Container from dlt.common.configuration.specs import ContainerInjectableContext @@ -9,7 +9,7 @@ class ContextProvider(Provider): - NAME = "Injectable Context" + NAME: ClassVar[str] = "Injectable Context" def __init__(self) -> None: self.container = Container() diff --git a/dlt/common/configuration/providers/provider.py b/dlt/common/configuration/providers/provider.py index 0ecd69833c..9635734f91 100644 --- a/dlt/common/configuration/providers/provider.py +++ b/dlt/common/configuration/providers/provider.py @@ -4,8 +4,6 @@ class Provider(abc.ABC): - # def __init__(self) -> None: - # pass @abc.abstractmethod def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Optional[Any], str]: @@ -25,8 +23,3 @@ def supports_namespaces(self) -> bool: @abc.abstractmethod def name(self) -> str: pass - - -def detect_known_providers() -> None: - # detects providers flagged - pass \ No newline at end of file diff --git a/dlt/common/configuration/providers/toml.py b/dlt/common/configuration/providers/toml.py index ee5a4c75aa..c4916d9b3a 100644 --- a/dlt/common/configuration/providers/toml.py +++ b/dlt/common/configuration/providers/toml.py @@ -1,23 +1,80 @@ import os -import inspect -import dataclasses +import abc import tomlkit -from inspect import Signature, Parameter -from typing import Any, List, Type -# from makefun import wraps -from functools import wraps +from typing import Any, Optional, Tuple, Type -from dlt.common.typing import DictStrAny, StrAny, TAny, TFun -from dlt.common.configuration import resolve_configuration, is_valid_hint -from dlt.common.configuration.specs import BaseConfiguration +from dlt.common.typing import StrAny +from .provider import Provider -def _read_toml(file_name: str) -> StrAny: - config_file_path = 
os.path.abspath(os.path.join(".", "experiments/.dlt", file_name)) - if os.path.isfile(config_file_path): - with open(config_file_path, "r", encoding="utf-8") as f: - # use whitespace preserving parser - return tomlkit.load(f) - else: - return {} \ No newline at end of file +class TomlProvider(Provider): + + def __init__(self, file_name: str, project_dir: str = None) -> None: + self._file_name = file_name + self._toml_path = os.path.join(project_dir or os.path.abspath(os.path.join(".", ".dlt")), file_name) + self._toml = self._read_toml(self._toml_path) + + @staticmethod + def get_key_name(key: str, *namespaces: str) -> str: + # env key is always upper case + if namespaces: + namespaces = filter(lambda x: bool(x), namespaces) # type: ignore + env_key = ".".join((*namespaces, key)) + else: + env_key = key + return env_key + + def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Optional[Any], str]: + full_path = namespaces + (key,) + full_key = self.get_key_name(key, *namespaces) + node = self._toml + try: + for k in full_path: + node = node[k] + return node, full_key + except KeyError: + return None, full_key + + @property + def supports_namespaces(self) -> bool: + return True + + @staticmethod + def _read_toml(toml_path: str) -> StrAny: + if os.path.isfile(toml_path): + # TODO: raise an exception with an explanation to the end user what is this toml file that does not parse etc. + with open(toml_path, "r", encoding="utf-8") as f: + # use whitespace preserving parser + return tomlkit.load(f) + else: + return {} + + +class ConfigTomlProvider(TomlProvider): + + def __init__(self, project_dir: str = None) -> None: + super().__init__("config.toml", project_dir) + + @property + def name(self) -> str: + return "Pipeline config.toml" + + @property + def supports_secrets(self) -> bool: + return False + + + +class SecretsTomlProvider(TomlProvider): + + def __init__(self, project_dir: str = None) -> None: + super().__init__("secrets.toml", project_dir) + + @property + def name(self) -> str: + return "Pipeline secrets.toml" + + @property + def supports_secrets(self) -> bool: + return True diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index e9ae20207c..ea773c2d08 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -1,5 +1,4 @@ import ast -from contextlib import _GeneratorContextManager import inspect from collections.abc import Mapping as C_Mapping from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple, Type, TypeVar, get_origin @@ -8,10 +7,10 @@ from dlt.common.typing import TSecretValue, is_optional_type, extract_inner_type from dlt.common.schema.utils import coerce_type, py_type_to_sc_type -from dlt.common.configuration.specs.base_configuration import BaseConfiguration, CredentialsConfiguration +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, CredentialsConfiguration, ContainerInjectableContext from dlt.common.configuration.specs.config_namespace_context import ConfigNamespacesContext from dlt.common.configuration.container import Container -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersListContext +from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext from dlt.common.configuration.providers.container import ContextProvider from dlt.common.configuration.exceptions import (LookupTrace, ConfigEntryMissingException, ConfigurationWrongTypeException, 
ConfigEnvValueCannotBeCoercedException, ValueNotSecretException, InvalidInitialValue) @@ -26,49 +25,6 @@ def resolve_configuration(config: TConfiguration, *, namespaces: Tuple[str, ...] return _resolve_configuration(config, namespaces, (), initial_value, accept_partial) -def _resolve_configuration( - config: TConfiguration, - explicit_namespaces: Tuple[str, ...], - embedded_namespaces: Tuple[str, ...], - initial_value: Any, accept_partial: bool - ) -> TConfiguration: - # do not resolve twice - if config.is_resolved(): - return config - - config.__exception__ = None - try: - # parse initial value if possible - if initial_value is not None: - try: - config.from_native_representation(initial_value) - except (NotImplementedError, ValueError): - # if parsing failed and initial_values is dict then apply - # TODO: we may try to parse with json here if str - if isinstance(initial_value, C_Mapping): - config.update(initial_value) - else: - raise InvalidInitialValue(type(config), type(initial_value)) - - try: - _resolve_config_fields(config, explicit_namespaces, embedded_namespaces, accept_partial) - _check_configuration_integrity(config) - # full configuration was resolved - config.__is_resolved__ = True - except ConfigEntryMissingException as cm_ex: - if not accept_partial: - raise - else: - # store the ConfigEntryMissingException to have full info on traces of missing fields - config.__exception__ = cm_ex - except Exception as ex: - # store the exception that happened in the resolution process - config.__exception__ = ex - raise - - return config - - def deserialize_value(key: str, value: Any, hint: Type[Any]) -> Any: try: if hint != Any: @@ -128,13 +84,81 @@ def inject_namespace(namespace_context: ConfigNamespacesContext, merge_existing: return container.injectable_context(namespace_context) -# def _add_module_version(config: BaseConfiguration) -> None: -# try: -# v = sys._getframe(1).f_back.f_globals["__version__"] -# semver.VersionInfo.parse(v) -# setattr(config, "_version", v) # noqa: B010 -# except KeyError: -# pass +def _resolve_configuration( + config: TConfiguration, + explicit_namespaces: Tuple[str, ...], + embedded_namespaces: Tuple[str, ...], + initial_value: Any, + accept_partial: bool + ) -> TConfiguration: + # do not resolve twice + if config.is_resolved(): + return config + + config.__exception__ = None + try: + # if initial value is a Mapping then apply it + if isinstance(initial_value, C_Mapping): + config.update(initial_value) + # cannot be native initial value + initial_value = None + + # try to get the native representation of the configuration using the config namespace as a key + # allows, for example, to store connection string or service.json in their native form in single env variable or under single vault key + resolved_initial: Any = None + if config.__namespace__ or embedded_namespaces: + cf_n, emb_ns = _apply_embedded_namespaces_to_config_namespace(config.__namespace__, embedded_namespaces) + resolved_initial, traces = _resolve_single_field(cf_n, type(config), None, explicit_namespaces, emb_ns) + _log_traces(config, cf_n, type(config), resolved_initial, traces) + # initial values cannot be dictionaries + if not isinstance(resolved_initial, C_Mapping): + initial_value = resolved_initial or initial_value + # if this is injectable context then return it immediately + if isinstance(resolved_initial, ContainerInjectableContext): + return resolved_initial # type: ignore + try: + try: + # use initial value to set config values + if initial_value: + 
config.from_native_representation(initial_value) + # if no initial value or initial value was passed via argument, resolve config normally (config always over explicit params) + if not initial_value or not resolved_initial: + raise NotImplementedError() + except ValueError: + raise InvalidInitialValue(type(config), type(initial_value)) + except NotImplementedError: + # if config does not support native form, resolve normally + _resolve_config_fields(config, explicit_namespaces, embedded_namespaces, accept_partial) + + _check_configuration_integrity(config) + # full configuration was resolved + config.__is_resolved__ = True + except ConfigEntryMissingException as cm_ex: + if not accept_partial: + raise + else: + # store the ConfigEntryMissingException to have full info on traces of missing fields + config.__exception__ = cm_ex + except Exception as ex: + # store the exception that happened in the resolution process + config.__exception__ = ex + raise + + return config + + +def _apply_embedded_namespaces_to_config_namespace(config_namespace: str, embedded_namespaces: Tuple[str, ...]) -> Tuple[str, Tuple[str, ...]]: + # for the configurations that have __namespace__ (config_namespace) defined and are embedded in other configurations, + # the innermost embedded namespace replaces config_namespace + if embedded_namespaces: + config_namespace = embedded_namespaces[-1] + embedded_namespaces = embedded_namespaces[:-1] + # if config_namespace: + return config_namespace, embedded_namespaces + + +def _is_secret_hint(hint: Type[Any]) -> bool: + return hint is TSecretValue or (inspect.isclass(hint) and issubclass(hint, CredentialsConfiguration)) def _resolve_config_fields( @@ -154,39 +178,32 @@ def _resolve_config_fields( is_optional = is_optional_type(hint) # accept partial becomes True if type if optional so we do not fail on optional configs that do not resolve fully accept_partial = accept_partial or is_optional - # if actual value is BaseConfiguration, resolve that instance + + # if current value is BaseConfiguration, resolve that instance if isinstance(current_value, BaseConfiguration): # resolve only if not yet resolved otherwise just pass it if not current_value.is_resolved(): # add key as innermost namespace current_value = _resolve_configuration(current_value, explicit_namespaces, embedded_namespaces + (key,), None, accept_partial) else: - # resolve key value via active providers - value, traces = _resolve_single_field(key, hint, config.__namespace__, explicit_namespaces, embedded_namespaces) - - # log trace - if logger.is_logging() and logger.log_level() == "DEBUG": - logger.debug(f"Field {key} with type {hint} in {type(config).__name__} {'NOT RESOLVED' if value is None else 'RESOLVED'}") - # print(f"Field {key} with type {hint} in {type(config).__name__} {'NOT RESOLVED' if value is None else 'RESOLVED'}") - for tr in traces: - # print(str(tr)) - logger.debug(str(tr)) - # extract hint from Optional / Literal / NewType hints - hint = extract_inner_type(hint) + inner_hint = extract_inner_type(hint) # extract origin from generic types - hint = get_origin(hint) or hint - # if hint is BaseConfiguration then resolve it recursively - if inspect.isclass(hint) and issubclass(hint, BaseConfiguration): - if isinstance(value, BaseConfiguration): - # if value is base configuration already (ie. 
via ContainerProvider) return it directly - current_value = value - else: - # create new instance and pass value from the provider as initial, add key to namespaces - current_value = _resolve_configuration(hint(), explicit_namespaces, embedded_namespaces + (key,), value or current_value, accept_partial) + inner_hint = get_origin(inner_hint) or inner_hint + + # if inner_hint is BaseConfiguration then resolve it recursively + if inspect.isclass(inner_hint) and issubclass(inner_hint, BaseConfiguration): + # create new instance and pass value from the provider as initial, add key to namespaces + current_value = _resolve_configuration(inner_hint(), explicit_namespaces, embedded_namespaces + (key,), current_value, accept_partial) else: + + # resolve key value via active providers passing the original hint ie. to preserve TSecretValue + value, traces = _resolve_single_field(key, hint, config.__namespace__, explicit_namespaces, embedded_namespaces) + _log_traces(config, key, hint, value, traces) + # if value is resolved, then deserialize and coerce it if value is not None: - current_value = deserialize_value(key, value, hint) + current_value = deserialize_value(key, value, inner_hint) + # collect unresolved fields if not is_optional and current_value is None: unresolved_fields[key] = traces @@ -196,6 +213,15 @@ def _resolve_config_fields( raise ConfigEntryMissingException(type(config).__name__, unresolved_fields) +def _log_traces(config: BaseConfiguration, key: str, hint: Type[Any], value: Any, traces: Sequence[LookupTrace]) -> None: + if logger.is_logging() and logger.log_level() == "DEBUG": + logger.debug(f"Field {key} with type {hint} in {type(config).__name__} {'NOT RESOLVED' if value is None else 'RESOLVED'}") + # print(f"Field {key} with type {hint} in {type(config).__name__} {'NOT RESOLVED' if value is None else 'RESOLVED'}") + for tr in traces: + # print(str(tr)) + logger.debug(str(tr)) + + def _check_configuration_integrity(config: BaseConfiguration) -> None: # python multi-inheritance is cooperative and this would require that all configurations cooperatively # call each other check_integrity. 
this is not at all possible as we do not know which configs in the end will @@ -219,10 +245,10 @@ def _resolve_single_field( ) -> Tuple[Optional[Any], List[LookupTrace]]: container = Container() # get providers from container - providers = container[ConfigProvidersListContext].providers + providers = container[ConfigProvidersContext].providers # get additional namespaces to look in from container namespaces_context = container[ConfigNamespacesContext] - # pipeline_name = ctx_namespaces.pipeline_name + config_namespace, embedded_namespaces = _apply_embedded_namespaces_to_config_namespace(config_namespace, embedded_namespaces) # start looking from the top provider with most specific set of namespaces first traces: List[LookupTrace] = [] @@ -248,23 +274,26 @@ def look_namespaces(pipeline_name: str = None) -> Any: value = None while True: - if pipeline_name or config_namespace: + if (pipeline_name or config_namespace) and provider.supports_namespaces: full_ns = ns.copy() # pipeline, when provided, is the most outer and always present if pipeline_name: full_ns.insert(0, pipeline_name) - # config namespace, when provided, is innermost and always present - if config_namespace and provider.supports_namespaces: + # config namespace, is always present and innermost + if config_namespace: full_ns.append(config_namespace) else: full_ns = ns value, ns_key = provider.get_value(key, hint, *full_ns) - # create trace, ignore container provider - if provider.name != ContextProvider.NAME: - traces.append(LookupTrace(provider.name, full_ns, ns_key, value)) # if secret is obtained from non secret provider, we must fail - if value is not None and not provider.supports_secrets and (hint is TSecretValue or (inspect.isclass(hint) and issubclass(hint, CredentialsConfiguration))): + cant_hold_it: bool = not provider.supports_secrets and _is_secret_hint(hint) + if value is not None and cant_hold_it: raise ValueNotSecretException(provider.name, ns_key) + + # create trace, ignore container provider and providers that cant_hold_it + if provider.name != ContextProvider.NAME and not cant_hold_it: + traces.append(LookupTrace(provider.name, full_ns, ns_key, value)) + if value is not None: # value found, ignore other providers return value diff --git a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py index 2f9b1119ef..36d7c474d7 100644 --- a/dlt/common/configuration/specs/base_configuration.py +++ b/dlt/common/configuration/specs/base_configuration.py @@ -95,7 +95,7 @@ def from_native_representation(self, native_value: Any) -> None: NotImplementedError: This configuration does not have a native representation ValueError: The value provided cannot be parsed as native representation """ - raise ValueError() + raise NotImplementedError() def to_native_representation(self) -> Any: """Represents the configuration instance in its native form ie. database connection string or JSON serialized GCP service credentials file. @@ -106,7 +106,7 @@ def to_native_representation(self) -> Any: Returns: Any: A native representation of the configuration """ - raise ValueError() + raise NotImplementedError() def get_resolvable_fields(self) -> Dict[str, type]: """Returns a mapping of fields to their type hints. Dunder should not be resolved and are not returned""" @@ -167,7 +167,9 @@ def __fields_dict(self) -> Dict[str, TDtcField]: @configspec class CredentialsConfiguration(BaseConfiguration): """Base class for all credentials. 
Credentials are configurations that may be stored only by providers supporting secrets.""" - pass + + __namespace__: str = "credentials" + @configspec diff --git a/dlt/common/configuration/specs/config_providers_context.py b/dlt/common/configuration/specs/config_providers_context.py index 00dc9d7efb..f6576b8367 100644 --- a/dlt/common/configuration/specs/config_providers_context.py +++ b/dlt/common/configuration/specs/config_providers_context.py @@ -5,34 +5,35 @@ from dlt.common.configuration.providers import Provider from dlt.common.configuration.providers.environ import EnvironProvider from dlt.common.configuration.providers.container import ContextProvider +from dlt.common.configuration.providers.toml import SecretsTomlProvider, ConfigTomlProvider from dlt.common.configuration.specs.base_configuration import BaseConfiguration, ContainerInjectableContext, configspec @configspec -class ConfigProvidersListContext(ContainerInjectableContext): +class ConfigProvidersContext(ContainerInjectableContext): """Injectable list of providers used by the configuration `resolve` module""" providers: List[Provider] def __init__(self) -> None: super().__init__() # add default providers, ContextProvider must be always first - it will provide contexts - self.providers = [ContextProvider(), EnvironProvider()] + self.providers = [ContextProvider(), EnvironProvider(), SecretsTomlProvider(), ConfigTomlProvider()] - def get_provider(self, name: str) -> Provider: + def __getitem__(self, name: str) -> Provider: try: return next(p for p in self.providers if p.name == name) except StopIteration: raise KeyError(name) - def has_provider(self, name: str) -> bool: + def __contains__(self, name: object) -> bool: try: - self.get_provider(name) + self.__getitem__(name) # type: ignore return True except KeyError: return False def add_provider(self, provider: Provider) -> None: - if self.has_provider(provider.name): + if provider.name in self: raise DuplicateProviderException(provider.name) self.providers.append(provider) diff --git a/dlt/common/configuration/specs/gcp_client_credentials.py b/dlt/common/configuration/specs/gcp_client_credentials.py index 7d9ffd7695..857f0ab97c 100644 --- a/dlt/common/configuration/specs/gcp_client_credentials.py +++ b/dlt/common/configuration/specs/gcp_client_credentials.py @@ -8,8 +8,6 @@ @configspec class GcpClientCredentials(CredentialsConfiguration): - __namespace__: str = "gcp" - project_id: str = None type: str = "service_account" # noqa: A003 private_key: TSecretValue = None diff --git a/dlt/common/configuration/specs/postgres_credentials.py b/dlt/common/configuration/specs/postgres_credentials.py index cc745de284..42d2361183 100644 --- a/dlt/common/configuration/specs/postgres_credentials.py +++ b/dlt/common/configuration/specs/postgres_credentials.py @@ -7,8 +7,6 @@ @configspec class PostgresCredentials(CredentialsConfiguration): - __namespace__: str = "pg" - dbname: str = None password: TSecretValue = None user: str = None diff --git a/dlt/common/destination.py b/dlt/common/destination.py index 67ad218e5d..c3cf38b3f2 100644 --- a/dlt/common/destination.py +++ b/dlt/common/destination.py @@ -135,4 +135,6 @@ def resolve_destination_reference(destination: Union[None, str, DestinationRefer if isinstance(destination, str): # TODO: figure out if this is full module path name or name of one of the known destinations # if destination is a str, get destination reference by dynamically importing module from known location - return import_module(f"dlt.load.{destination}") \ No newline at end 
of file + return import_module(f"dlt.load.{destination}") + + return destination \ No newline at end of file diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index d9fd25d711..3be316dad6 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -1,3 +1,5 @@ +import os +import tempfile from typing import Any, Callable, ClassVar, Protocol, Sequence from dlt.common.configuration.container import ContainerInjectableContext @@ -43,3 +45,17 @@ def is_activated(self) -> bool: def __init__(self, deferred_pipeline: Callable[..., SupportsPipeline]) -> None: self._deferred_pipeline = deferred_pipeline + + +def get_default_working_dir() -> str: + if os.geteuid() == 0: + # we are root so use standard /var + return os.path.join("/var", "dlt", "pipelines") + + home = os.path.expanduser("~") + if home is None: + # no home dir - use temp + return os.path.join(tempfile.gettempdir(), "dlt", "pipelines") + else: + # if home directory is available use ~/.dlt/pipelines + return os.path.join(home, ".dlt", "pipelines") diff --git a/experiments/pipeline/__init__.py b/experiments/pipeline/__init__.py index a9bc16ec79..53627f01ca 100644 --- a/experiments/pipeline/__init__.py +++ b/experiments/pipeline/__init__.py @@ -1,11 +1,10 @@ -import tempfile from typing import Union from dlt.common.typing import TSecretValue, Any from dlt.common.configuration import with_config from dlt.common.configuration.container import Container from dlt.common.destination import DestinationReference, resolve_destination_reference -from dlt.common.pipeline import PipelineContext +from dlt.common.pipeline import PipelineContext, get_default_working_dir from experiments.pipeline.configuration import PipelineConfiguration from experiments.pipeline.pipeline import Pipeline @@ -48,7 +47,7 @@ def pipeline( print(kwargs["_last_dlt_config"].pipeline_name) # if working_dir not provided use temp folder if not working_dir: - working_dir = tempfile.gettempdir() + working_dir = get_default_working_dir() destination = resolve_destination_reference(destination) # create new pipeline instance p = Pipeline(pipeline_name, working_dir, pipeline_secret, destination, dataset_name, import_schema_path, export_schema_path, always_drop_pipeline, kwargs["runtime"]) diff --git a/tests/.example.env b/tests/.example.env index 0a9a700dcf..c38ab0530d 100644 --- a/tests/.example.env +++ b/tests/.example.env @@ -4,14 +4,14 @@ DEFAULT_DATASET=carbon_bot_3 -GCP__PROJECT_ID=chat-analytics-317513 -GCP__PRIVATE_KEY="-----BEGIN PRIVATE KEY----- +CREDENTIALS__PROJECT_ID=chat-analytics-317513 +CREDENTIALS__PRIVATE_KEY="-----BEGIN PRIVATE KEY----- paste key here -----END PRIVATE KEY----- " -CLIENT_EMAIL=loader@chat-analytics-317513.iam.gserviceaccount.com +CREDENTIALS__CLIENT_EMAIL=loader@chat-analytics-317513.iam.gserviceaccount.com -PG__DBNAME=chat_analytics_rasa -PG__USER=loader -PG__HOST=3.73.90.3 -PG__PASSWORD=set-me-up \ No newline at end of file +CREDENTIALS__DBNAME=chat_analytics_rasa +CREDENTIALS__USER=loader +CREDENTIALS__HOST=3.73.90.3 +CREDENTIALS__PASSWORD=set-me-up \ No newline at end of file diff --git a/tests/common/cases/configuration/.dlt/config.toml b/tests/common/cases/configuration/.dlt/config.toml new file mode 100644 index 0000000000..13e287065f --- /dev/null +++ b/tests/common/cases/configuration/.dlt/config.toml @@ -0,0 +1,27 @@ +api_type="REST" + +api.url="http" +api.port=1024 + +[api.params] +param1="a" +param2="b" + +[typecheck] +str_val="test string" +int_val=12345 +bool_val=true +list_val=[1, "2", [3]] +dict_val={'a'=1, "b"="2"} 
+float_val=1.18927 +tuple_val=[1, 2, {1="complicated dicts allowed in literal eval"}] +COMPLEX_VAL={"_"= [1440, ["*"], []], "change-email"= [560, ["*"], []]} +date_val=1979-05-27T07:32:00-08:00 +dec_val="22.38" # always use text to pass decimals +bytes_val="0x48656c6c6f20576f726c6421" # always use text to pass hex value that should be converted to bytes +any_val="function() {}" +none_val="none" +sequence_val=["A", "B", "KAPPA"] +gen_list_val=["C", "Z", "N"] +mapping_val={"FL"=1, "FR"={"1"=2}} +mutable_mapping_val={"str"="str"} diff --git a/tests/common/cases/configuration/.dlt/secrets.toml b/tests/common/cases/configuration/.dlt/secrets.toml index e69de29bb2..42e11d46dd 100644 --- a/tests/common/cases/configuration/.dlt/secrets.toml +++ b/tests/common/cases/configuration/.dlt/secrets.toml @@ -0,0 +1,70 @@ +secret_value="2137" +api.port=1023 + +# holds a literal string that can be parsed as gcp credentials +source.credentials=''' +{ + "type": "service_account", + "project_id": "mock-project-id-source.credentials", + "private_key_id": "62c1f8f00836dec27c8d96d1c0836df2c1f6bce4", + "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCNEN0bL39HmD+S\n...\n-----END PRIVATE KEY-----\n", + "client_email": "loader@a7513.iam.gserviceaccount.com", + "client_id": "114701312674477307596", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/loader%40mock-project-id-2.iam.gserviceaccount.com", + "file_upload_timeout": 819872989 + } +''' + +[credentials] +secret_value="2137" +"project_id"="mock-project-id-credentials" + +[gcp_storage] +"project_id"="mock-project-id-gcp-storage" +"private_key"="-----BEGIN PRIVATE KEY-----\nMIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCNEN0bL39HmD+S\n....\n-----END PRIVATE KEY-----\n" +"client_email"="loader@a7513.iam.gserviceaccount.com" + +[destination.redshift.credentials] +dbname="destination.redshift.credentials" +user="loader" +host="3.73.90.3" +password="set-me-up" + +[destination.credentials] +"type"="service_account" +"project_id"="mock-project-id-destination.credentials" +"private_key_id"="62c1f8f00836dec27c8d96d1c0836df2c1f6bce4" +"private_key"="-----BEGIN PRIVATE KEY-----\nMIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCNEN0bL39HmD+S\n....\n-----END PRIVATE KEY-----\n" +"client_email"="loader@a7513.iam.gserviceaccount.com" +"client_id"="114701312674477307596" +"auth_uri"="https://accounts.google.com/o/oauth2/auth" +"token_uri"="https://oauth2.googleapis.com/token" +"auth_provider_x509_cert_url"="https://www.googleapis.com/oauth2/v1/certs" +"client_x509_cert_url"="https://www.googleapis.com/robot/v1/metadata/x509/loader%40mock-project-id-1.iam.gserviceaccount.com" + +[destination.bigquery] +"type"="service_account" +"project_id"="mock-project-id-destination.bigquery" +"private_key_id"="62c1f8f00836dec27c8d96d1c0836df2c1f6bce4" +"private_key"="-----BEGIN PRIVATE KEY-----\nMIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCNEN0bL39HmD+S\n....\n-----END PRIVATE KEY-----\n" +"client_email"="loader@a7513.iam.gserviceaccount.com" +"client_id"="114701312674477307596" +"auth_uri"="https://accounts.google.com/o/oauth2/auth" +"token_uri"="https://oauth2.googleapis.com/token" +"auth_provider_x509_cert_url"="https://www.googleapis.com/oauth2/v1/certs" 
+"client_x509_cert_url"="https://www.googleapis.com/robot/v1/metadata/x509/loader%40mock-project-id-1.iam.gserviceaccount.com" + +[destination.bigquery.credentials] +"type"="service_account" +"project_id"="mock-project-id-destination.bigquery.credentials" +"private_key_id"="62c1f8f00836dec27c8d96d1c0836df2c1f6bce4" +"private_key"="-----BEGIN PRIVATE KEY-----\nMIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCNEN0bL39HmD+S\n....\n-----END PRIVATE KEY-----\n" +"client_email"="loader@a7513.iam.gserviceaccount.com" +"client_id"="114701312674477307596" +"auth_uri"="https://accounts.google.com/o/oauth2/auth" +"token_uri"="https://oauth2.googleapis.com/token" +"auth_provider_x509_cert_url"="https://www.googleapis.com/oauth2/v1/certs" +"client_x509_cert_url"="https://www.googleapis.com/robot/v1/metadata/x509/loader%40mock-project-id-1.iam.gserviceaccount.com" diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index d6437c2fbc..a442e49366 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -1,44 +1,18 @@ import pytest import datetime # noqa: I251 -from typing import Any, Dict, List, Mapping, MutableMapping, NewType, Optional, Sequence, Tuple, Type +from typing import Any, Dict, List, Mapping, MutableMapping, NewType, Optional, Type from dlt.common import pendulum, Decimal, Wei from dlt.common.utils import custom_environ -from dlt.common.typing import StrAny, TSecretValue, extract_inner_type +from dlt.common.typing import TSecretValue, extract_inner_type from dlt.common.configuration.exceptions import ConfigFieldMissingTypeHintException, ConfigFieldTypeHintNotSupported, InvalidInitialValue, LookupTrace, ValueNotSecretException from dlt.common.configuration import configspec, ConfigEntryMissingException, ConfigEnvValueCannotBeCoercedException, resolve from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration from dlt.common.configuration.specs.base_configuration import is_valid_hint -from dlt.common.configuration.providers import environ as environ_provider +from dlt.common.configuration.providers import environ as environ_provider, toml from tests.utils import preserve_environ, add_config_dict_to_env -from tests.common.configuration.utils import MockProvider, WithCredentialsConfiguration, WrongConfiguration, SecretConfiguration, NamespacedConfiguration, environment, mock_provider - -COERCIONS = { - 'str_val': 'test string', - 'int_val': 12345, - 'bool_val': True, - 'list_val': [1, "2", [3]], - 'dict_val': { - 'a': 1, - "b": "2" - }, - 'bytes_val': b'Hello World!', - 'float_val': 1.18927, - "tuple_val": (1, 2, {1: "complicated dicts allowed in literal eval"}), - 'any_val': "function() {}", - 'none_val': "none", - 'COMPLEX_VAL': { - "_": [1440, ["*"], []], - "change-email": [560, ["*"], []] - }, - "date_val": pendulum.now(), - "dec_val": Decimal("22.38"), - "sequence_val": ["A", "B", "KAPPA"], - "gen_list_val": ["C", "Z", "N"], - "mapping_val": {"FL": 1, "FR": {"1": 2}}, - "mutable_mapping_val": {"str": "str"} -} +from tests.common.configuration.utils import MockProvider, CoercionTestConfiguration, COERCIONS, WithCredentialsConfiguration, WrongConfiguration, SecretConfiguration, NamespacedConfiguration, environment, mock_provider INVALID_COERCIONS = { # 'STR_VAL': 'test string', # string always OK @@ -68,28 +42,6 @@ } -@configspec -class CoercionTestConfiguration(RunConfiguration): - pipeline_name: str = "Some Name" - str_val: str = None - int_val: int 
= None - bool_val: bool = None - list_val: list = None # type: ignore - dict_val: dict = None # type: ignore - bytes_val: bytes = None - float_val: float = None - tuple_val: Tuple[int, int, StrAny] = None - any_val: Any = None - none_val: str = None - COMPLEX_VAL: Dict[str, Tuple[int, List[str], List[str]]] = None - date_val: datetime.datetime = None - dec_val: Decimal = None - sequence_val: Sequence[str] = None - gen_list_val: List[str] = None - mapping_val: StrAny = None - mutable_mapping_val: MutableMapping[str, str] = None - - @configspec class VeryWrongConfiguration(WrongConfiguration): pipeline_name: str = "Some Name" @@ -169,25 +121,53 @@ class EmbeddedSecretConfiguration(BaseConfiguration): def test_initial_config_state() -> None: assert BaseConfiguration.__is_resolved__ is False assert BaseConfiguration.__namespace__ is None - C = BaseConfiguration() - assert C.__is_resolved__ is False - assert C.is_resolved() is False + c = BaseConfiguration() + assert c.__is_resolved__ is False + assert c.is_resolved() is False # base configuration has no resolvable fields so is never partial - assert C.is_partial() is False + assert c.is_partial() is False def test_set_initial_config_value(environment: Any) -> None: # set from init method - C = resolve.resolve_configuration(InstrumentedConfiguration(head="h", tube=["a", "b"], heels="he")) - assert C.to_native_representation() == "h>a>b>he" + c = resolve.resolve_configuration(InstrumentedConfiguration(head="h", tube=["a", "b"], heels="he")) + assert c.to_native_representation() == "h>a>b>he" # set from native form - C = resolve.resolve_configuration(InstrumentedConfiguration(), initial_value="h>a>b>he") - assert C.head == "h" - assert C.tube == ["a", "b"] - assert C.heels == "he" + c = resolve.resolve_configuration(InstrumentedConfiguration(), initial_value="h>a>b>he") + assert c.head == "h" + assert c.tube == ["a", "b"] + assert c.heels == "he" # set from dictionary - C = resolve.resolve_configuration(InstrumentedConfiguration(), initial_value={"head": "h", "tube": ["tu", "be"], "heels": "xhe"}) - assert C.to_native_representation() == "h>tu>be>xhe" + c = resolve.resolve_configuration(InstrumentedConfiguration(), initial_value={"head": "h", "tube": ["tu", "be"], "heels": "xhe"}) + assert c.to_native_representation() == "h>tu>be>xhe" + + +def test_initial_native_representation_skips_resolve(environment: Any) -> None: + c = InstrumentedConfiguration() + # mock namespace to enable looking for initials in provider + c.__namespace__ = "ins" + # explicit initial does not skip resolve + environment["INS__HEELS"] = "xhe" + c = resolve.resolve_configuration(c, initial_value="h>a>b>he") + assert c.heels == "xhe" + + # now put the whole native representation in env + environment["INS"] = "h>a>b>he" + c = InstrumentedConfiguration() + c.__namespace__ = "ins" + c = resolve.resolve_configuration(c, initial_value="h>a>b>uhe") + assert c.heels == "he" + + +def test_query_initial_config_value_if_config_namespace(environment: Any) -> None: + c = InstrumentedConfiguration(head="h", tube=["a", "b"], heels="he") + # mock the __namespace__ to enable the query + c.__namespace__ = "snake" + # provide the initial value + environment["SNAKE"] = "h>tu>be>xhe" + c = resolve.resolve_configuration(c) + # check if the initial value loaded + assert c.heels == "xhe" def test_invalid_initial_config_value() -> None: @@ -212,7 +192,7 @@ def test_embedded_config(environment: Any) -> None: assert C.namespaced.password == "pwd" # resolve but providing values via env - with 
custom_environ({"INSTRUMENTED": "h>tu>u>be>xhe", "DLT_TEST__PASSWORD": "passwd", "DEFAULT": "DEF"}): + with custom_environ({"INSTRUMENTED": "h>tu>u>be>xhe", "NAMESPACED__PASSWORD": "passwd", "DEFAULT": "DEF"}): C = resolve.resolve_configuration(EmbeddedConfiguration()) assert C.default == "DEF" assert C.instrumented.to_native_representation() == "h>tu>u>be>xhe" @@ -225,7 +205,7 @@ def test_embedded_config(environment: Any) -> None: assert not C.instrumented.__is_resolved__ # some are partial, some are not - with custom_environ({"DLT_TEST__PASSWORD": "passwd"}): + with custom_environ({"NAMESPACED__PASSWORD": "passwd"}): C = resolve.resolve_configuration(EmbeddedConfiguration(), accept_partial=True) assert not C.__is_resolved__ assert C.namespaced.__is_resolved__ @@ -357,8 +337,10 @@ def test_raises_on_unresolved_field(environment: Any) -> None: assert "NoneConfigVar" in cf_missing_exc.value.traces # has only one trace trace = cf_missing_exc.value.traces["NoneConfigVar"] - assert len(trace) == 1 + assert len(trace) == 3 assert trace[0] == LookupTrace("Environment Variables", [], "NONECONFIGVAR", None) + assert trace[1] == LookupTrace("Pipeline secrets.toml", [], "NoneConfigVar", None) + assert trace[2] == LookupTrace("Pipeline config.toml", [], "NoneConfigVar", None) def test_raises_on_many_unresolved_fields(environment: Any) -> None: @@ -371,8 +353,10 @@ def test_raises_on_many_unresolved_fields(environment: Any) -> None: traces = cf_missing_exc.value.traces assert len(traces) == len(val_fields) for tr_field, exp_field in zip(traces, val_fields): - assert len(traces[tr_field]) == 1 + assert len(traces[tr_field]) == 3 assert traces[tr_field][0] == LookupTrace("Environment Variables", [], environ_provider.EnvironProvider.get_key_name(exp_field), None) + assert traces[tr_field][1] == LookupTrace("Pipeline secrets.toml", [], toml.TomlProvider.get_key_name(exp_field), None) + assert traces[tr_field][2] == LookupTrace("Pipeline config.toml", [], toml.TomlProvider.get_key_name(exp_field), None) def test_accepts_optional_missing_fields(environment: Any) -> None: diff --git a/tests/common/configuration/test_container.py b/tests/common/configuration/test_container.py index 786a04494f..2581174e93 100644 --- a/tests/common/configuration/test_container.py +++ b/tests/common/configuration/test_container.py @@ -7,7 +7,7 @@ from dlt.common.configuration.specs import BaseConfiguration, ContainerInjectableContext from dlt.common.configuration.container import Container from dlt.common.configuration.exceptions import ContainerInjectableContextMangled, InvalidInitialValue, ContextDefaultCannotBeCreated -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersListContext +from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext from tests.utils import preserve_environ from tests.common.configuration.utils import environment @@ -17,6 +17,9 @@ class InjectableTestContext(ContainerInjectableContext): current_value: str + def from_native_representation(self, native_value: Any) -> None: + raise ValueError(native_value) + @configspec class EmbeddedWithInjectableContext(BaseConfiguration): @@ -141,7 +144,7 @@ def test_container_provider_embedded_inject(container: Container, environment: A assert C.injected.current_value == "Embed" assert C.injected is injected # remove first provider - container[ConfigProvidersListContext].providers.pop(0) + container[ConfigProvidersContext].providers.pop(0) # now environment will provide unparsable value with 
pytest.raises(InvalidInitialValue): C = resolve_configuration(EmbeddedWithInjectableContext()) diff --git a/tests/common/configuration/test_namespaces.py b/tests/common/configuration/test_namespaces.py index e8f462e90c..24d28568e2 100644 --- a/tests/common/configuration/test_namespaces.py +++ b/tests/common/configuration/test_namespaces.py @@ -21,6 +21,11 @@ class EmbeddedConfiguration(BaseConfiguration): sv_config: Optional[SingleValConfiguration] +@configspec +class EmbeddedWithNamespacedConfiguration(BaseConfiguration): + embedded: NamespacedConfiguration + + def test_namespaced_configuration(environment: Any) -> None: with pytest.raises(ConfigEntryMissingException) as exc_val: resolve.resolve_configuration(NamespacedConfiguration()) @@ -30,8 +35,10 @@ def test_namespaced_configuration(environment: Any) -> None: # check trace traces = exc_val.value.traces["password"] # only one provider and namespace was tried - assert len(traces) == 1 + assert len(traces) == 3 assert traces[0] == LookupTrace("Environment Variables", ["DLT_TEST"], "DLT_TEST__PASSWORD", None) + assert traces[1] == LookupTrace("Pipeline secrets.toml", ["DLT_TEST"], "DLT_TEST.password", None) + assert traces[2] == LookupTrace("Pipeline config.toml", ["DLT_TEST"], "DLT_TEST.password", None) # init vars work without namespace C = resolve.resolve_configuration(NamespacedConfiguration(), initial_value={"password": "PASS"}) @@ -77,28 +84,39 @@ def test_explicit_namespaces_with_namespaced_config(mock_provider: MockProvider) mock_provider.return_value_on = ("DLT_TEST",) resolve.resolve_configuration(NamespacedConfiguration()) assert mock_provider.last_namespace == ("DLT_TEST",) - # namespace from config is mandatory, provider will not be queried with () - assert mock_provider.last_namespaces == [("DLT_TEST",)] + # first the native representation of NamespacedConfiguration is queried with (), and then the fields in NamespacedConfiguration are queried only in DLT_TEST + assert mock_provider.last_namespaces == [(), ("DLT_TEST",)] # namespaced config is always innermost mock_provider.reset_stats() resolve.resolve_configuration(NamespacedConfiguration(), namespaces=("ns1",)) - assert mock_provider.last_namespaces == [("ns1", "DLT_TEST"), ("DLT_TEST",)] + assert mock_provider.last_namespaces == [("ns1",), (), ("ns1", "DLT_TEST"), ("DLT_TEST",)] mock_provider.reset_stats() resolve.resolve_configuration(NamespacedConfiguration(), namespaces=("ns1", "ns2")) - assert mock_provider.last_namespaces == [("ns1", "ns2", "DLT_TEST"), ("ns1", "DLT_TEST"), ("DLT_TEST",)] + assert mock_provider.last_namespaces == [("ns1", "ns2"), ("ns1",), (), ("ns1", "ns2", "DLT_TEST"), ("ns1", "DLT_TEST"), ("DLT_TEST",)] + + +def test_overwrite_config_namespace_from_embedded(mock_provider: MockProvider) -> None: + mock_provider.value = {} + mock_provider.return_value_on = ("embedded",) + resolve.resolve_configuration(EmbeddedWithNamespacedConfiguration()) + # when resolving the config namespace DLT_TEST was removed and the embedded namespace was used instead + assert mock_provider.last_namespace == ("embedded",) + # lookup in order: () - parent config when looking for "embedded", then from "embedded" config + assert mock_provider.last_namespaces == [(), ("embedded",)] def test_explicit_namespaces_from_embedded_config(mock_provider: MockProvider) -> None: mock_provider.value = {"sv": "A"} + mock_provider.return_value_on = ("sv_config",) C = resolve.resolve_configuration(EmbeddedConfiguration()) # we mock the dictionary below as the value for all requests assert 
C.sv_config.sv == '{"sv": "A"}' - # following namespaces were used when resolving EmbeddedConfig: () - to resolve sv_config and then: ("sv_config",), () to resolve sv in sv_config - assert mock_provider.last_namespaces == [(), ("sv_config",), ()] + # following namespaces were used when resolving EmbeddedConfig: () trying to get initial value for the whole embedded sv_config, then ("sv_config",), () to resolve sv in sv_config + assert mock_provider.last_namespaces == [(), ("sv_config",)] # embedded namespace inner of explicit mock_provider.reset_stats() C = resolve.resolve_configuration(EmbeddedConfiguration(), namespaces=("ns1",)) - assert mock_provider.last_namespaces == [("ns1",), (), ("ns1", "sv_config",), ("ns1",), ()] + assert mock_provider.last_namespaces == [("ns1",), (), ("ns1", "sv_config",), ("sv_config",)] def test_injected_namespaces(mock_provider: MockProvider) -> None: @@ -116,14 +134,14 @@ def test_injected_namespaces(mock_provider: MockProvider) -> None: mock_provider.reset_stats() mock_provider.return_value_on = ("DLT_TEST",) resolve.resolve_configuration(NamespacedConfiguration()) - assert mock_provider.last_namespaces == [("inj-ns1", "DLT_TEST"), ("DLT_TEST",)] + assert mock_provider.last_namespaces == [("inj-ns1",), (), ("inj-ns1", "DLT_TEST"), ("DLT_TEST",)] # injected namespace inner of ns coming from embedded config mock_provider.reset_stats() mock_provider.return_value_on = () mock_provider.value = {"sv": "A"} resolve.resolve_configuration(EmbeddedConfiguration()) # first we look for sv_config -> ("inj-ns1",), () then we look for sv - assert mock_provider.last_namespaces == [("inj-ns1", ), (), ("inj-ns1", "sv_config"), ("inj-ns1",), ()] + assert mock_provider.last_namespaces == [("inj-ns1",), (), ("inj-ns1", "sv_config"), ("sv_config",)] # multiple injected namespaces with container.injectable_context(ConfigNamespacesContext(namespaces=("inj-ns1", "inj-ns2"))): @@ -134,7 +152,6 @@ def test_injected_namespaces(mock_provider: MockProvider) -> None: def test_namespace_with_pipeline_name(mock_provider: MockProvider) -> None: - # AXIES__DESTINATION__STORAGE_CREDENTIALS__PRIVATE_KEY, DESTINATION__STORAGE_CREDENTIALS__PRIVATE_KEY, DESTINATION__PRIVATE_KEY, GCP__PRIVATE_KEY # if pipeline name is present, keys will be looked up twice: with pipeline as top level namespace and without it container = Container() @@ -165,7 +182,8 @@ def test_namespace_with_pipeline_name(mock_provider: MockProvider) -> None: mock_provider.return_value_on = ("DLT_TEST",) mock_provider.reset_stats() resolve.resolve_configuration(NamespacedConfiguration()) - assert mock_provider.last_namespaces == [("PIPE", "DLT_TEST"), ("DLT_TEST",)] + # first the whole NamespacedConfiguration is looked under key DLT_TEST (namespaces: ('PIPE',), ()), then fields of NamespacedConfiguration + assert mock_provider.last_namespaces == [('PIPE',), (), ("PIPE", "DLT_TEST"), ("DLT_TEST",)] # with pipeline and injected namespaces with container.injectable_context(ConfigNamespacesContext(pipeline_name="PIPE", namespaces=("inj-ns1",))): diff --git a/tests/common/configuration/test_toml_provider.py b/tests/common/configuration/test_toml_provider.py new file mode 100644 index 0000000000..6a8128154c --- /dev/null +++ b/tests/common/configuration/test_toml_provider.py @@ -0,0 +1,169 @@ +import pytest +from typing import Any, Iterator +import datetime # noqa: I251 + + +from dlt.common import pendulum +from dlt.common.configuration import configspec, ConfigEntryMissingException, ConfigFileNotFoundException, resolve +from 
dlt.common.configuration.container import Container +from dlt.common.configuration.inject import with_config +from dlt.common.configuration.exceptions import LookupTrace +from dlt.common.configuration.providers.environ import EnvironProvider +from dlt.common.configuration.providers.toml import SecretsTomlProvider, ConfigTomlProvider +from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext +from dlt.common.configuration.specs import BaseConfiguration, GcpClientCredentials, PostgresCredentials +from dlt.common.typing import TSecretValue + +from tests.utils import preserve_environ +from tests.common.configuration.utils import WithCredentialsConfiguration, CoercionTestConfiguration, COERCIONS, SecretConfiguration, environment + + +@configspec +class EmbeddedWithGcpStorage(BaseConfiguration): + gcp_storage: GcpClientCredentials + + +@configspec +class EmbeddedWithGcpCredentials(BaseConfiguration): + credentials: GcpClientCredentials + + +@pytest.fixture +def providers() -> Iterator[ConfigProvidersContext]: + pipeline_root = "./tests/common/cases/configuration/.dlt" + ctx = ConfigProvidersContext() + ctx.providers.clear() + ctx.add_provider(SecretsTomlProvider(project_dir=pipeline_root)) + ctx.add_provider(ConfigTomlProvider(project_dir=pipeline_root)) + with Container().injectable_context(ctx): + yield ctx + + +def test_secrets_from_toml_secrets() -> None: + with pytest.raises(ConfigEntryMissingException) as py_ex: + resolve.resolve_configuration(SecretConfiguration()) + + # only two traces because TSecretValue won't be checked in config.toml provider + traces = py_ex.value.traces["secret_value"] + assert len(traces) == 2 + assert traces[0] == LookupTrace("Environment Variables", [], "SECRET_VALUE", None) + assert traces[1] == LookupTrace("Pipeline secrets.toml", [], "secret_value", None) + + with pytest.raises(ConfigEntryMissingException) as py_ex: + resolve.resolve_configuration(WithCredentialsConfiguration()) + + +def test_toml_types(providers: ConfigProvidersContext) -> None: + # resolve CoercionTestConfiguration from typecheck namespace + c = resolve.resolve_configuration(CoercionTestConfiguration(), namespaces=("typecheck",)) + for k, v in COERCIONS.items(): + # toml does not know tuples + if isinstance(v, tuple): + v = list(v) + if isinstance(v, datetime.datetime): + v = pendulum.parse("1979-05-27T07:32:00-08:00") + assert v == c[k] + + +def test_config_provider_order(providers: ConfigProvidersContext, environment: Any) -> None: + + # add env provider + providers.providers.insert(0, EnvironProvider()) + + @with_config(namespaces=("api",)) + def single_val(port): + return port + + # secrets have api.port=1023 and this will be used + assert single_val() == 1023 + + # env will make it string, also namespace is optional + environment["PORT"] = "UNKNOWN" + assert single_val() == "UNKNOWN" + + environment["API__PORT"] = "1025" + assert single_val() == "1025" + + +def test_toml_mixed_config_inject(providers: ConfigProvidersContext) -> None: + # get data from both providers + + @with_config + def mixed_val(api_type, secret_value: TSecretValue, typecheck: Any): + return api_type, secret_value, typecheck + + _tup = mixed_val() + assert _tup[0] == "REST" + assert _tup[1] == "2137" + assert isinstance(_tup[2], dict) + + +def test_toml_namespaces(providers: ConfigProvidersContext) -> None: + cfg = providers["Pipeline config.toml"] + assert cfg.get_value("api_type", str) == ("REST", "api_type") + assert cfg.get_value("port", int, "api") == (1024, "api.port") + assert 
cfg.get_value("param1", str, "api", "params") == ("a", "api.params.param1") + + +def test_secrets_toml_credentials(providers: ConfigProvidersContext) -> None: + # there are credentials exactly under destination.bigquery.credentials + c = resolve.resolve_configuration(GcpClientCredentials(), namespaces=("destination", "bigquery")) + assert c.project_id.endswith("destination.bigquery.credentials") + # there are no destination.gcp_storage.credentials so it will fallback to "destination"."credentials" + c = resolve.resolve_configuration(GcpClientCredentials(), namespaces=("destination", "gcp_storage")) + assert c.project_id.endswith("destination.credentials") + # also explicit + c = resolve.resolve_configuration(GcpClientCredentials(), namespaces=("destination",)) + assert c.project_id.endswith("destination.credentials") + # there's "credentials" key but does not contain valid gcp credentials + with pytest.raises(ConfigEntryMissingException): + resolve.resolve_configuration(GcpClientCredentials()) + # also try postgres credentials + c = resolve.resolve_configuration(PostgresCredentials(), namespaces=("destination", "redshift")) + assert c.dbname == "destination.redshift.credentials" + # bigquery credentials do not match redshift credentials + with pytest.raises(ConfigEntryMissingException): + resolve.resolve_configuration(PostgresCredentials(), namespaces=("destination", "bigquery")) + + + +def test_secrets_toml_embedded_credentials(providers: ConfigProvidersContext) -> None: + # will try destination.bigquery.credentials + c = resolve.resolve_configuration(EmbeddedWithGcpCredentials(), namespaces=("destination", "bigquery")) + assert c.credentials.project_id.endswith("destination.bigquery.credentials") + # will try destination.gcp_storage.credentials and fallback to destination.credentials + c = resolve.resolve_configuration(EmbeddedWithGcpCredentials(), namespaces=("destination", "gcp_storage")) + assert c.credentials.project_id.endswith("destination.credentials") + # will try everything until credentials in the root where incomplete credentials are present + c = EmbeddedWithGcpCredentials() + # create embedded config that will be passed as initial + c.credentials = GcpClientCredentials() + with pytest.raises(ConfigEntryMissingException) as py_ex: + resolve.resolve_configuration(c, namespaces=("middleware", "storage")) + # so we can read partially filled configuration here + assert c.credentials.project_id.endswith("-credentials") + assert set(py_ex.value.traces.keys()) == {"client_email", "private_key"} + + # embed "gcp_storage" will bubble up to the very top, never reverts to "credentials" + c = resolve.resolve_configuration(EmbeddedWithGcpStorage(), namespaces=("destination", "bigquery")) + assert c.gcp_storage.project_id.endswith("-gcp-storage") + + # also explicit + c = resolve.resolve_configuration(GcpClientCredentials(), namespaces=("destination",)) + assert c.project_id.endswith("destination.credentials") + # there's "credentials" key but does not contain valid gcp credentials + with pytest.raises(ConfigEntryMissingException): + resolve.resolve_configuration(GcpClientCredentials()) + + +# def test_secrets_toml_ignore_dict_initial(providers: ConfigProvidersContext) -> None: + + + +def test_secrets_toml_credentials_from_native_repr(providers: ConfigProvidersContext) -> None: + # cfg = providers["Pipeline secrets.toml"] + # print(cfg._toml) + # print(cfg._toml["source"]["credentials"]) + # resolve gcp_credentials by parsing initial value which is str holding json doc + c = 
resolve.resolve_configuration(GcpClientCredentials(), namespaces=("source",)) + assert c.project_id.endswith("source.credentials") diff --git a/tests/common/configuration/utils.py b/tests/common/configuration/utils.py index 02d028b1be..9ba58bb5e8 100644 --- a/tests/common/configuration/utils.py +++ b/tests/common/configuration/utils.py @@ -1,10 +1,12 @@ import pytest from os import environ -from typing import Any, List, Optional, Tuple, Type -from dlt.common.configuration.container import Container -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersListContext +import datetime # noqa: I251 +from typing import Any, List, Optional, Tuple, Type, Dict, MutableMapping, Optional, Sequence -from dlt.common.typing import TSecretValue +from dlt.common import Decimal, pendulum +from dlt.common.configuration.container import Container +from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext +from dlt.common.typing import TSecretValue, StrAny from dlt.common.configuration import configspec from dlt.common.configuration.providers import Provider from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration, RunConfiguration @@ -17,6 +19,28 @@ class WrongConfiguration(RunConfiguration): log_color: bool = True +@configspec +class CoercionTestConfiguration(RunConfiguration): + pipeline_name: str = "Some Name" + str_val: str = None + int_val: int = None + bool_val: bool = None + list_val: list = None # type: ignore + dict_val: dict = None # type: ignore + bytes_val: bytes = None + float_val: float = None + tuple_val: Tuple[int, int, StrAny] = None + any_val: Any = None + none_val: str = None + COMPLEX_VAL: Dict[str, Tuple[int, List[str], List[str]]] = None + date_val: datetime.datetime = None + dec_val: Decimal = None + sequence_val: Sequence[str] = None + gen_list_val: List[str] = None + mapping_val: StrAny = None + mutable_mapping_val: MutableMapping[str, str] = None + + @configspec class SecretConfiguration(BaseConfiguration): secret_value: TSecretValue = None @@ -48,7 +72,7 @@ def environment() -> Any: @pytest.fixture(scope="function") def mock_provider() -> "MockProvider": container = Container() - with container.injectable_context(ConfigProvidersListContext()) as providers: + with container.injectable_context(ConfigProvidersContext()) as providers: # replace all providers with MockProvider that does not support secrets mock_provider = MockProvider() providers.providers = [mock_provider] @@ -93,3 +117,30 @@ class SecretMockProvider(MockProvider): @property def supports_secrets(self) -> bool: return True + + +COERCIONS = { + 'str_val': 'test string', + 'int_val': 12345, + 'bool_val': True, + 'list_val': [1, "2", [3]], + 'dict_val': { + 'a': 1, + "b": "2" + }, + 'bytes_val': b'Hello World!', + 'float_val': 1.18927, + "tuple_val": (1, 2, {"1": "complicated dicts allowed in literal eval"}), + 'any_val': "function() {}", + 'none_val': "none", + 'COMPLEX_VAL': { + "_": [1440, ["*"], []], + "change-email": [560, ["*"], []] + }, + "date_val": pendulum.now(), + "dec_val": Decimal("22.38"), + "sequence_val": ["A", "B", "KAPPA"], + "gen_list_val": ["C", "Z", "N"], + "mapping_val": {"FL": 1, "FR": {"1": 2}}, + "mutable_mapping_val": {"str": "str"} +} \ No newline at end of file diff --git a/tests/dbt_runner/test_runner_redshift.py b/tests/dbt_runner/test_runner_redshift.py index 8e650602df..0c915ea709 100644 --- a/tests/dbt_runner/test_runner_redshift.py +++ b/tests/dbt_runner/test_runner_redshift.py @@ -25,7 +25,7 @@ 
@pytest.fixture(scope="module", autouse=True) def module_autouse() -> None: # disable GCP in environ - del environ["GCP__PROJECT_ID"] + del environ["CREDENTIALS__PROJECT_ID"] # set the test case for the unit tests environ["DEFAULT_DATASET"] = "test_fixture_carbon_bot_session_cases" add_config_to_env(PostgresCredentials) diff --git a/tests/load/bigquery/test_bigquery_table_builder.py b/tests/load/bigquery/test_bigquery_table_builder.py index 73a9d31fe7..a8ffe2b02b 100644 --- a/tests/load/bigquery/test_bigquery_table_builder.py +++ b/tests/load/bigquery/test_bigquery_table_builder.py @@ -21,11 +21,11 @@ def schema() -> Schema: def test_configuration() -> None: # check names normalized - with custom_environ({"GCP__PRIVATE_KEY": "---NO NEWLINE---\n"}): + with custom_environ({"CREDENTIALS__PRIVATE_KEY": "---NO NEWLINE---\n"}): C = resolve_configuration(GcpClientCredentials()) assert C.private_key == "---NO NEWLINE---\n" - with custom_environ({"GCP__PRIVATE_KEY": "---WITH NEWLINE---\n"}): + with custom_environ({"CREDENTIALS__PRIVATE_KEY": "---WITH NEWLINE---\n"}): C = resolve_configuration(GcpClientCredentials()) assert C.private_key == "---WITH NEWLINE---\n" diff --git a/tests/load/redshift/test_redshift_table_builder.py b/tests/load/redshift/test_redshift_table_builder.py index e007d8b37a..d89ed5aa71 100644 --- a/tests/load/redshift/test_redshift_table_builder.py +++ b/tests/load/redshift/test_redshift_table_builder.py @@ -27,7 +27,7 @@ def client(schema: Schema) -> RedshiftClient: def test_configuration() -> None: # check names normalized - with custom_environ({"PG__DBNAME": "UPPER_CASE_DATABASE", "PG__PASSWORD": " pass\n"}): + with custom_environ({"CREDENTIALS__DBNAME": "UPPER_CASE_DATABASE", "CREDENTIALS__PASSWORD": " pass\n"}): C = resolve_configuration(PostgresCredentials()) assert C.dbname == "upper_case_database" assert C.password == "pass" diff --git a/tests/utils.py b/tests/utils.py index b586b8ef3b..17dc6884da 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -10,7 +10,7 @@ from dlt.common.configuration.providers import EnvironProvider, DictionaryProvider from dlt.common.configuration.resolve import resolve_configuration, serialize_value from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersListContext +from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext from dlt.common.logger import init_logging_from_config from dlt.common.storages import FileStorage from dlt.common.schema import Schema @@ -22,9 +22,9 @@ # add test dictionary provider def TEST_DICT_CONFIG_PROVIDER(): - providers_context = Container()[ConfigProvidersListContext] + providers_context = Container()[ConfigProvidersContext] try: - return providers_context.get_provider(DictionaryProvider.NAME) + return providers_context[DictionaryProvider.NAME] except KeyError: provider = DictionaryProvider() providers_context.add_provider(provider) From 8b4d6fdb07fbadcf94666800d1e2c6984e087594 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 28 Oct 2022 11:30:51 +0200 Subject: [PATCH 51/66] ports pipeline v1 util methods --- experiments/pipeline/pipeline.py | 46 ++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/experiments/pipeline/pipeline.py b/experiments/pipeline/pipeline.py index 5ffe84dba0..64394cc12d 100644 --- a/experiments/pipeline/pipeline.py +++ b/experiments/pipeline/pipeline.py @@ -11,6 +11,7 @@ from dlt.common.runners.runnable 
import Runnable from dlt.common.schema.typing import TColumnSchema, TWriteDisposition from dlt.common.source import DLT_METADATA_FIELD, TResolvableDataItem, with_table_name +from dlt.common.storages.load_storage import LoadStorage from dlt.common.typing import DictStrAny, StrAny, TFun, TSecretValue, TAny from dlt.common.runners import pool_runner as runner, TRunMetrics, initialize_runner @@ -23,14 +24,15 @@ from dlt.common.storages.file_storage import FileStorage from dlt.common.utils import is_interactive from dlt.extract.extract import ExtractorStorage, extract +from dlt.load.job_client_impl import SqlJobClientBase from dlt.normalize import Normalize -from dlt.load.client_base import SqlClientBase +from dlt.load.sql_client import SqlClientBase from dlt.load.configuration import LoaderConfiguration from dlt.load import Load from dlt.normalize.configuration import NormalizeConfiguration -from experiments.pipeline.exceptions import PipelineConfigMissing, MissingDependencyException, PipelineStepFailed +from experiments.pipeline.exceptions import PipelineConfigMissing, MissingDependencyException, PipelineStepFailed, SqlClientNotAvailable from dlt.extract.sources import DltResource, DltSource, TTableSchemaTemplate from experiments.pipeline.typing import TPipelineStep, TPipelineState from experiments.pipeline.configuration import StateInjectableContext @@ -337,6 +339,46 @@ def default_schema(self) -> Schema: def last_run_exception(self) -> BaseException: return runner.LAST_RUN_EXCEPTION + def list_extracted_resources(self) -> Sequence[str]: + return self._get_normalize_storage().list_files_to_normalize_sorted() + + def list_normalized_load_packages(self) -> Sequence[str]: + return self._get_load_storage().load_storage.list_packages() + + def list_completed_load_packages(self) -> Sequence[str]: + return self._get_load_storage().load_storage.list_completed_packages() + + def list_failed_jobs_in_package(self, load_id: str) -> Sequence[Tuple[str, str]]: + storage = self._get_load_storage() + failed_jobs: List[Tuple[str, str]] = [] + for file in storage.load_storage.list_completed_failed_jobs(load_id): + if not file.endswith(".exception"): + try: + failed_message = storage.storage.load(file + ".exception") + except FileNotFoundError: + failed_message = None + failed_jobs.append((file, failed_message)) + return failed_jobs + + def sync_schema(self, schema_name: str = None) -> None: + with self._get_destination_client(self.schemas[schema_name]) as client: + client.initialize_storage(wipe_data=self.always_drop_pipeline) + client.update_storage_schema() + + def sql_client(self, schema_name: str = None) -> SqlClientBase[Any]: + with self._get_destination_client(self.schemas[schema_name]) as client: + if isinstance(client, SqlJobClientBase): + return client.sql_client + else: + raise SqlClientNotAvailable(self.destination.name()) + + def _get_normalize_storage(self) -> NormalizeStorage: + return NormalizeStorage(True, self._normalize_storage_config) + + def _get_load_storage(self) -> LoadStorage: + caps = self._get_destination_capabilities() + return LoadStorage(True, caps.preferred_loader_file_format, caps.supported_loader_file_formats, self._load_storage_config) + @with_state_sync def _configure(self, pipeline_name: str, working_dir: str, import_schema_path: str, export_schema_path: str, always_drop_pipeline: bool) -> None: self.pipeline_name = pipeline_name From 3fa7d2193f91cac20f9bf0a33d346f951a89f3d0 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 28 Oct 2022 11:31:19 +0200 Subject: [PATCH 
52/66] renames modules holding client implementations in load --- dlt/helpers/pandas.py | 2 +- dlt/load/bigquery/bigquery.py | 4 ++-- dlt/load/{client_base_impl.py => job_client_impl.py} | 2 +- dlt/load/load.py | 2 +- dlt/load/redshift/redshift.py | 4 ++-- dlt/load/{client_base.py => sql_client.py} | 0 dlt/pipeline/pipeline.py | 2 +- tests/load/test_client.py | 4 ++-- tests/load/test_dummy_client.py | 2 +- tests/load/utils.py | 2 +- 10 files changed, 12 insertions(+), 12 deletions(-) rename dlt/load/{client_base_impl.py => job_client_impl.py} (98%) rename dlt/load/{client_base.py => sql_client.py} (100%) diff --git a/dlt/helpers/pandas.py b/dlt/helpers/pandas.py index 3d9dcb1c59..1fb5249929 100644 --- a/dlt/helpers/pandas.py +++ b/dlt/helpers/pandas.py @@ -1,7 +1,7 @@ from typing import Any from dlt.pipeline.exceptions import MissingDependencyException -from dlt.load.client_base import SqlClientBase +from dlt.load.sql_client import SqlClientBase try: import pandas as pd diff --git a/dlt/load/bigquery/bigquery.py b/dlt/load/bigquery/bigquery.py index 0b228192e4..7eee8ac7c7 100644 --- a/dlt/load/bigquery/bigquery.py +++ b/dlt/load/bigquery/bigquery.py @@ -19,8 +19,8 @@ from dlt.common.schema import TColumnSchema, TDataType, Schema, TTableSchemaColumns from dlt.load.typing import DBCursor -from dlt.load.client_base import SqlClientBase -from dlt.load.client_base_impl import SqlJobClientBase +from dlt.load.sql_client import SqlClientBase +from dlt.load.job_client_impl import SqlJobClientBase from dlt.load.exceptions import LoadClientSchemaWillNotUpdate, LoadJobNotExistsException, LoadJobServerTerminalException, LoadUnknownTableException from dlt.load.bigquery import capabilities diff --git a/dlt/load/client_base_impl.py b/dlt/load/job_client_impl.py similarity index 98% rename from dlt/load/client_base_impl.py rename to dlt/load/job_client_impl.py index c681a61992..5085c1f368 100644 --- a/dlt/load/client_base_impl.py +++ b/dlt/load/job_client_impl.py @@ -8,7 +8,7 @@ from dlt.common.destination import DestinationClientConfiguration, TLoadJobStatus, LoadJob, JobClientBase from dlt.load.typing import TNativeConn -from dlt.load.client_base import SqlClientBase +from dlt.load.sql_client import SqlClientBase from dlt.load.exceptions import LoadClientSchemaVersionCorrupted diff --git a/dlt/load/load.py b/dlt/load/load.py index a867915bd8..2846613f36 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -14,7 +14,7 @@ from dlt.common.telemetry import get_logging_extras, set_gauge_all_labels from dlt.common.destination import JobClientBase, DestinationReference, LoadJob, TLoadJobStatus, DestinationClientConfiguration -from dlt.load.client_base_impl import LoadEmptyJob +from dlt.load.job_client_impl import LoadEmptyJob from dlt.load.configuration import LoaderConfiguration from dlt.load.exceptions import LoadClientTerminalException, LoadClientTransientException, LoadClientUnsupportedWriteDisposition, LoadClientUnsupportedFileFormats, LoadJobNotExistsException, LoadUnknownTableException diff --git a/dlt/load/redshift/redshift.py b/dlt/load/redshift/redshift.py index be43841218..ad677baaac 100644 --- a/dlt/load/redshift/redshift.py +++ b/dlt/load/redshift/redshift.py @@ -20,8 +20,8 @@ from dlt.load.exceptions import LoadClientSchemaWillNotUpdate, LoadClientTerminalInnerException, LoadClientTransientInnerException from dlt.load.typing import DBCursor -from dlt.load.client_base import SqlClientBase -from dlt.load.client_base_impl import SqlJobClientBase, LoadEmptyJob +from dlt.load.sql_client import 
SqlClientBase +from dlt.load.job_client_impl import SqlJobClientBase, LoadEmptyJob from dlt.load.redshift import capabilities from dlt.load.redshift.configuration import RedshiftClientConfiguration diff --git a/dlt/load/client_base.py b/dlt/load/sql_client.py similarity index 100% rename from dlt/load/client_base.py rename to dlt/load/sql_client.py diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 27cc83acef..34f89d38bb 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -20,7 +20,7 @@ from dlt.common.source import DLT_METADATA_FIELD, TItem, with_table_name from dlt.extract.extractor_storage import ExtractorStorageBase -from dlt.load.client_base import SqlClientBase, SqlJobClientBase +from dlt.load.sql_client import SqlClientBase, SqlJobClientBase from dlt.normalize.configuration import configuration as normalize_configuration from dlt.load.configuration import configuration as loader_configuration from dlt.normalize import Normalize diff --git a/tests/load/test_client.py b/tests/load/test_client.py index 699d0e2c6f..b90c9b6f64 100644 --- a/tests/load/test_client.py +++ b/tests/load/test_client.py @@ -10,8 +10,8 @@ from dlt.common.schema import TTableSchemaColumns from dlt.common.utils import uniq_id -from dlt.load.client_base import DBCursor -from dlt.load.client_base_impl import SqlJobClientBase +from dlt.load.sql_client import DBCursor +from dlt.load.job_client_impl import SqlJobClientBase from tests.utils import TEST_STORAGE_ROOT, delete_test_storage from tests.common.utils import load_json_case diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index 16194212c3..405bab254e 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -14,7 +14,7 @@ from dlt.common.destination import DestinationReference, LoadJob from dlt.load import Load -from dlt.load.client_base_impl import LoadEmptyJob +from dlt.load.job_client_impl import LoadEmptyJob from dlt.load import dummy from dlt.load.dummy import dummy as dummy_impl diff --git a/tests/load/utils.py b/tests/load/utils.py index 80d39def16..15ca425daf 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -16,7 +16,7 @@ from dlt.common.utils import uniq_id from dlt.load import Load -from dlt.load.client_base_impl import SqlJobClientBase +from dlt.load.job_client_impl import SqlJobClientBase TABLE_UPDATE: List[TColumnSchema] = [ From 08c10dbff2456161f6bb377c3abaf73ce07c1468 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 28 Oct 2022 11:32:44 +0200 Subject: [PATCH 53/66] removes pipeline v1 --- dlt/pipeline/__init__.py | 4 - dlt/pipeline/exceptions.py | 72 ------- dlt/pipeline/pipeline.py | 426 ------------------------------------- dlt/pipeline/typing.py | 98 --------- 4 files changed, 600 deletions(-) delete mode 100644 dlt/pipeline/__init__.py delete mode 100644 dlt/pipeline/exceptions.py delete mode 100644 dlt/pipeline/pipeline.py delete mode 100644 dlt/pipeline/typing.py diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py deleted file mode 100644 index 661ddc5ec9..0000000000 --- a/dlt/pipeline/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from dlt.common.schema import Schema # noqa: F401 -from dlt.pipeline.pipeline import Pipeline # noqa: F401 -from dlt.pipeline.typing import GCPPipelineCredentials, PostgresPipelineCredentials # noqa: F401 -from dlt.pipeline.exceptions import CannotRestorePipelineException # noqa: F401 diff --git a/dlt/pipeline/exceptions.py b/dlt/pipeline/exceptions.py deleted file mode 100644 index 
88ffe4c3c4..0000000000 --- a/dlt/pipeline/exceptions.py +++ /dev/null @@ -1,72 +0,0 @@ -from typing import Any, Sequence -from dlt.common.exceptions import DltException -from dlt.common.telemetry import TRunMetrics -from dlt.pipeline.typing import TPipelineStage - - -class PipelineException(DltException): - pass - - -class MissingDependencyException(PipelineException): - def __init__(self, caller: str, dependencies: Sequence[str], appendix: str = "") -> None: - self.caller = caller - self.dependencies = dependencies - super().__init__(self._get_msg(appendix)) - - def _get_msg(self, appendix: str) -> str: - msg = f""" -You must install additional dependencies to run {self.caller}. If you use pip you may do the following: - -{self._to_pip_install()} -""" - if appendix: - msg = msg + "\n" + appendix - return msg - - def _to_pip_install(self) -> str: - return "\n".join([f"pip install {d}" for d in self.dependencies]) - - -class NoPipelineException(PipelineException): - def __init__(self) -> None: - super().__init__("Please create or restore pipeline before using this function") - - -class InvalidPipelineContextException(PipelineException): - def __init__(self) -> None: - super().__init__("There may be just one active pipeline in single python process. You may have switch between pipelines by restoring pipeline just before using load method") - - -class CannotRestorePipelineException(PipelineException): - def __init__(self, reason: str) -> None: - super().__init__(reason) - - -class PipelineBackupNotFound(PipelineException): - def __init__(self, method: str) -> None: - self.method = method - super().__init__(f"Backup not found for method {method}") - - -class SqlClientNotAvailable(PipelineException): - def __init__(self, destination_name: str) -> None: - super().__init__(f"SQL Client not available for {destination_name}") - - -class InvalidIteratorException(PipelineException): - def __init__(self, iterator: Any) -> None: - super().__init__(f"Unsupported source iterator or iterable type: {type(iterator).__name__}") - - -class InvalidItemException(PipelineException): - def __init__(self, item: Any) -> None: - super().__init__(f"Source yielded unsupported item type: {type(item).__name}. 
Only dictionaries, sequences and deferred items allowed.") - - -class PipelineStepFailed(PipelineException): - def __init__(self, stage: TPipelineStage, exception: BaseException, run_metrics: TRunMetrics) -> None: - self.stage = stage - self.exception = exception - self.run_metrics = run_metrics - super().__init__(f"Pipeline execution failed at stage {stage} with exception:\n\n{type(exception)}\n{exception}") diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py deleted file mode 100644 index 34f89d38bb..0000000000 --- a/dlt/pipeline/pipeline.py +++ /dev/null @@ -1,426 +0,0 @@ - -from contextlib import contextmanager -from copy import deepcopy -import yaml -from collections import abc -from dataclasses import asdict as dtc_asdict -import tempfile -import os.path -from typing import Any, Iterator, List, Sequence, Tuple, Callable -from prometheus_client import REGISTRY - -from dlt.common import json, sleep, signals, logger -from dlt.common.runners import pool_runner as runner, TRunMetrics, initialize_runner -from dlt.common.configuration import resolve_configuration -from dlt.common.configuration.specs import PoolRunnerConfiguration -from dlt.common.storages import FileStorage -from dlt.common.schema import Schema -from dlt.common.typing import DictStrAny, StrAny -from dlt.common.utils import uniq_id, is_interactive -from dlt.common.source import DLT_METADATA_FIELD, TItem, with_table_name - -from dlt.extract.extractor_storage import ExtractorStorageBase -from dlt.load.sql_client import SqlClientBase, SqlJobClientBase -from dlt.normalize.configuration import configuration as normalize_configuration -from dlt.load.configuration import configuration as loader_configuration -from dlt.normalize import Normalize -from dlt.load import Load -from dlt.pipeline.exceptions import MissingDependencyException, NoPipelineException, PipelineStepFailed, CannotRestorePipelineException, SqlClientNotAvailable -from dlt.pipeline.typing import PipelineCredentials - - -class Pipeline: - def __init__(self, pipeline_name: str, log_level: str = "INFO") -> None: - self.pipeline_name = pipeline_name - self.root_path: str = None - self.export_schema_path: str = None - self.import_schema_path: str = None - self.root_storage: FileStorage = None - self.credentials: PipelineCredentials = None - self.extractor_storage: ExtractorStorageBase = None - self.default_schema_name: str = None - self.state: DictStrAny = {} - - # addresses of pipeline components to be verified before they are run - self._normalize_instance: Normalize = None - self._loader_instance: Load = None - - # patch config and initialize pipeline - self.C = resolve_configuration(PoolRunnerConfiguration(), initial_value={ - "PIPELINE_NAME": pipeline_name, - "LOG_LEVEL": log_level, - "POOL_TYPE": "None", - "IS_SINGLE_RUN": True, - "WAIT_RUNS": 0, - "EXIT_ON_EXCEPTION": True, - }) - initialize_runner(self.C) - - def create_pipeline( - self, - credentials: PipelineCredentials, - working_dir: str = None, - schema: Schema = None, - import_schema_path: str = None, - export_schema_path: str = None - ) -> None: - # initialize root storage - if not working_dir: - working_dir = tempfile.mkdtemp() - self.root_storage = FileStorage(working_dir, makedirs=True) - self.export_schema_path = export_schema_path - self.import_schema_path = import_schema_path - - # check if directory contains restorable pipeline - try: - self._restore_state() - # wipe out the old pipeline - self.root_storage.delete_folder("", recursively=True) - self.root_storage.create_folder("") - 
except FileNotFoundError: - pass - - self.root_path = self.root_storage.storage_path - self.credentials = credentials - self._load_modules() - self.extractor_storage = ExtractorStorageBase( - "1.0.0", - True, - FileStorage(os.path.join(self.root_path, "extractor"), makedirs=True), - self._normalize_instance.normalize_storage) - # create new schema if no default supplied - if schema is None: - # try to load schema, that will also import it - schema_name = self.pipeline_name - try: - schema = self._normalize_instance.schema_storage.load_schema(schema_name) - except FileNotFoundError: - # create new empty schema - schema = Schema(schema_name) - # initialize empty state, this must be last operation when creating pipeline so restore reads only fully created ones - with self._managed_state(): - self.state = { - # "default_schema_name": default_schema_name, - "pipeline_name": self.pipeline_name, - # TODO: must come from resolved configuration - "loader_client_type": credentials.CLIENT_TYPE, - # TODO: must take schema prefix from resolved configuration - "loader_schema_prefix": credentials.default_dataset - } - # persist schema with the pipeline - self.set_default_schema(schema) - - def restore_pipeline( - self, - credentials: PipelineCredentials, - working_dir: str, - import_schema_path: str = None, - export_schema_path: str = None - ) -> None: - try: - # do not create extractor dir - it must exist - self.root_storage = FileStorage(working_dir, makedirs=False) - # restore state, this must be a first operation when restoring pipeline - try: - self._restore_state() - except FileNotFoundError: - raise CannotRestorePipelineException(f"Cannot find a valid pipeline in {working_dir}") - restored_name = self.state["pipeline_name"] - if self.pipeline_name != restored_name: - raise CannotRestorePipelineException(f"Expected pipeline {self.pipeline_name}, found {restored_name} pipeline instead") - self.default_schema_name = self.state["default_schema_name"] - if not credentials.default_dataset: - credentials.default_dataset = self.state["loader_schema_prefix"] - self.root_path = self.root_storage.storage_path - self.credentials = credentials - self.export_schema_path = export_schema_path - self.import_schema_path = import_schema_path - self._load_modules() - # schema must exist - try: - self.get_default_schema() - except (FileNotFoundError): - raise CannotRestorePipelineException(f"Default schema with name {self.default_schema_name} not found") - self.extractor_storage = ExtractorStorageBase( - "1.0.0", - True, - FileStorage(os.path.join(self.root_path, "extractor"), makedirs=False), - self._normalize_instance.normalize_storage - ) - except CannotRestorePipelineException: - raise - - def extract(self, items: Iterator[TItem], schema_name: str = None, table_name: str = None) -> None: - # check if iterator or iterable is supported - # if isinstance(items, str) or isinstance(items, dict) or not - # TODO: check if schema exists - with self._managed_state(): - default_table_name = table_name or self.pipeline_name - # TODO: this is not very effective - we consume iterator right away, better implementation needed where we stream iterator to files directly - all_items: List[DictStrAny] = [] - for item in items: - # dispatch items by type - if callable(item): - item = item() - if isinstance(item, dict): - all_items.append(item) - elif isinstance(item, abc.Sequence): - all_items.extend(item) - # react to CTRL-C and shutdowns from controllers - signals.raise_if_signalled() - - try: - 
self._extract_iterator(default_table_name, all_items) - except Exception: - raise PipelineStepFailed("extract", self.last_run_exception, runner.LAST_RUN_METRICS) - - def normalize(self, workers: int = 1, max_events_in_chunk: int = 100000) -> int: - if is_interactive() and workers > 1: - raise NotImplementedError("Do not use workers in interactive mode ie. in notebook") - self._verify_normalize_instance() - # set runtime parameters - self._normalize_instance.config.workers = workers - # switch to thread pool for single worker - self._normalize_instance.config.pool_type = "thread" if workers == 1 else "process" - try: - ec = runner.run_pool(self._normalize_instance.config, self._normalize_instance) - # in any other case we raise if runner exited with status failed - if runner.LAST_RUN_METRICS.has_failed: - raise PipelineStepFailed("normalize", self.last_run_exception, runner.LAST_RUN_METRICS) - return ec - except Exception as r_ex: - # if EXIT_ON_EXCEPTION flag is set, exception will bubble up directly - raise PipelineStepFailed("normalize", self.last_run_exception, runner.LAST_RUN_METRICS) from r_ex - - def load(self, max_parallel_loads: int = 20) -> int: - self._verify_loader_instance() - self._loader_instance.config.workers = max_parallel_loads - self._loader_instance.load_client_cls.CONFIG.DEFAULT_SCHEMA_NAME = self.default_schema_name # type: ignore - try: - ec = runner.run_pool(self._loader_instance.config, self._loader_instance) - # in any other case we raise if runner exited with status failed - if runner.LAST_RUN_METRICS.has_failed: - raise PipelineStepFailed("load", self.last_run_exception, runner.LAST_RUN_METRICS) - return ec - except Exception as r_ex: - # if EXIT_ON_EXCEPTION flag is set, exception will bubble up directly - raise PipelineStepFailed("load", self.last_run_exception, runner.LAST_RUN_METRICS) from r_ex - - def flush(self) -> None: - self.normalize() - self.load() - - @property - def working_dir(self) -> str: - return os.path.abspath(self.root_path) - - @property - def last_run_exception(self) -> BaseException: - return runner.LAST_RUN_EXCEPTION - - def list_extracted_loads(self) -> Sequence[str]: - self._verify_loader_instance() - return self._normalize_instance.normalize_storage.list_files_to_normalize_sorted() - - def list_normalized_loads(self) -> Sequence[str]: - self._verify_loader_instance() - return self._loader_instance.load_storage.list_packages() - - def list_completed_loads(self) -> Sequence[str]: - self._verify_loader_instance() - return self._loader_instance.load_storage.list_completed_packages() - - def list_failed_jobs(self, load_id: str) -> Sequence[Tuple[str, str]]: - self._verify_loader_instance() - failed_jobs: List[Tuple[str, str]] = [] - for file in self._loader_instance.load_storage.list_completed_failed_jobs(load_id): - if not file.endswith(".exception"): - try: - failed_message = self._loader_instance.load_storage.storage.load(file + ".exception") - except FileNotFoundError: - failed_message = None - failed_jobs.append((file, failed_message)) - return failed_jobs - - def get_default_schema(self) -> Schema: - self._verify_normalize_instance() - return self._normalize_instance.schema_storage.load_schema(self.default_schema_name) - - def set_default_schema(self, new_schema: Schema) -> None: - if self.default_schema_name: - # delete old schema - try: - self._normalize_instance.schema_storage.remove_schema(self.default_schema_name) - self.default_schema_name = None - except FileNotFoundError: - pass - # save new schema - 
self._normalize_instance.schema_storage.save_schema(new_schema) - self.default_schema_name = new_schema.name - with self._managed_state(): - self.state["default_schema_name"] = self.default_schema_name - - def add_schema(self, aux_schema: Schema) -> None: - self._normalize_instance.schema_storage.save_schema(aux_schema) - - def get_schema(self, name: str) -> Schema: - return self._normalize_instance.schema_storage.load_schema(name) - - def remove_schema(self, name: str) -> None: - self._normalize_instance.schema_storage.remove_schema(name) - - def sync_schema(self) -> None: - self._verify_loader_instance() - schema = self._normalize_instance.schema_storage.load_schema(self.default_schema_name) - with self._loader_instance.load_client_cls(schema) as client: - client.initialize_storage() - client.update_storage_schema() - - def sql_client(self, schema_name: str = None) -> SqlClientBase[Any]: - self._verify_loader_instance() - schema = self._normalize_instance.schema_storage.load_schema(schema_name or self.default_schema_name) - with self._loader_instance.load_client_cls(schema) as c: - if isinstance(c, SqlJobClientBase): - return c.sql_client - else: - raise SqlClientNotAvailable(self._loader_instance.config.client_type) - - def run_in_pool(self, run_f: Callable[..., Any]) -> int: - # internal runners should work in single mode - self._loader_instance.config.is_single_run = True - self._loader_instance.config.exit_on_exception = True - self._normalize_instance.config.is_single_run = True - self._normalize_instance.config.exit_on_exception = True - - def _run(_: Any) -> TRunMetrics: - rv = run_f() - if isinstance(rv, TRunMetrics): - return rv - if isinstance(rv, int): - pending = rv - else: - pending = 1 - return TRunMetrics(False, False, int(pending)) - - # run the fun - ec = runner.run_pool(self.C, _run) - # ec > 0 - signalled - # -1 - runner was not able to start - - if runner.LAST_RUN_METRICS is not None and runner.LAST_RUN_METRICS.has_failed: - raise self.last_run_exception - return ec - - - def _configure_normalize(self) -> None: - # create normalize config - normalize_initial = { - "NORMALIZE_VOLUME_PATH": os.path.join(self.root_path, "normalize"), - "SCHEMA_VOLUME_PATH": os.path.join(self.root_path, "schemas"), - "EXPORT_SCHEMA_PATH": os.path.abspath(self.export_schema_path) if self.export_schema_path else None, - "IMPORT_SCHEMA_PATH": os.path.abspath(self.import_schema_path) if self.import_schema_path else None, - "LOADER_FILE_FORMAT": self._loader_instance.load_client_cls.capabilities()["preferred_loader_file_format"], - "ADD_EVENT_JSON": False - } - normalize_initial.update(self._configure_runner()) - C = normalize_configuration(initial_values=normalize_initial) - # shares schema storage with the pipeline so we do not need to install - self._normalize_instance = Normalize(C) - - def _configure_load(self) -> None: - # use credentials to populate loader client config, it includes also client type - loader_client_initial = dtc_asdict(self.credentials) - loader_client_initial["DEFAULT_SCHEMA_NAME"] = self.default_schema_name - # but client type must be passed to loader config - loader_initial = {"CLIENT_TYPE": loader_client_initial["CLIENT_TYPE"]} - loader_initial.update(self._configure_runner()) - loader_initial["DELETE_COMPLETED_JOBS"] = True - try: - C = loader_configuration(initial_values=loader_initial) - self._loader_instance = Load(C, REGISTRY, client_initial_values=loader_client_initial, is_storage_owner=True) - except ImportError: - raise MissingDependencyException( - 
f"{self.credentials.CLIENT_TYPE} destination", - [f"python-dlt[{self.credentials.CLIENT_TYPE}]"], - "Dependencies for specific destination are available as extras of python-dlt" - ) - - def _verify_loader_instance(self) -> None: - if self._loader_instance is None: - raise NoPipelineException() - - def _verify_normalize_instance(self) -> None: - if self._loader_instance is None: - raise NoPipelineException() - - def _configure_runner(self) -> StrAny: - return { - "PIPELINE_NAME": self.pipeline_name, - "IS_SINGLE_RUN": True, - "WAIT_RUNS": 0, - "EXIT_ON_EXCEPTION": True, - "LOAD_VOLUME_PATH": os.path.join(self.root_path, "normalized") - } - - def _load_modules(self) -> None: - # configure loader - self._configure_load() - # configure normalize - self._configure_normalize() - - def _extract_iterator(self, default_table_name: str, items: Sequence[DictStrAny]) -> None: - try: - for idx, i in enumerate(items): - if not isinstance(i, dict): - # TODO: convert non dict types into dict - items[idx] = i = {"v": i} - if DLT_METADATA_FIELD not in i or i.get(DLT_METADATA_FIELD, None) is None: - # set default table name - with_table_name(i, default_table_name) - - load_id = uniq_id() - self.extractor_storage.save_json(f"{load_id}.json", items) - self.extractor_storage.commit_events( - self.default_schema_name, - self.extractor_storage.storage.make_full_path(f"{load_id}.json"), - default_table_name, - len(items), - load_id - ) - - runner.LAST_RUN_METRICS = TRunMetrics(was_idle=False, has_failed=False, pending_items=0) - except Exception as ex: - logger.exception("extracting iterator failed") - runner.LAST_RUN_METRICS = TRunMetrics(was_idle=False, has_failed=True, pending_items=0) - runner.LAST_RUN_EXCEPTION = ex - raise - - @contextmanager - def _managed_state(self) -> Iterator[None]: - backup_state = deepcopy(self.state) - try: - yield - except Exception: - # restore old state - self.state.clear() - self.state.update(backup_state) - raise - else: - # persist old state - self.root_storage.save("state.json", json.dumps(self.state)) - - def _restore_state(self) -> None: - self.state.clear() - restored_state: DictStrAny = json.loads(self.root_storage.load("state.json")) - self.state.update(restored_state) - - @staticmethod - def save_schema_to_file(file_name: str, schema: Schema, remove_defaults: bool = True) -> None: - with open(file_name, "w", encoding="utf-8") as f: - f.write(schema.to_pretty_yaml(remove_defaults=remove_defaults)) - - @staticmethod - def load_schema_from_file(file_name: str) -> Schema: - with open(file_name, "r", encoding="utf-8") as f: - schema_dict: DictStrAny = yaml.safe_load(f) - return Schema.from_dict(schema_dict) diff --git a/dlt/pipeline/typing.py b/dlt/pipeline/typing.py deleted file mode 100644 index f269a0c17e..0000000000 --- a/dlt/pipeline/typing.py +++ /dev/null @@ -1,98 +0,0 @@ - -from typing import Literal, Type, Any -from dataclasses import dataclass, fields as dtc_fields -from dlt.common import json - -from dlt.common.typing import StrAny, TSecretValue - -TLoaderType = Literal["bigquery", "redshift", "dummy"] -TPipelineStage = Literal["extract", "normalize", "load"] - -# extractor generator yields functions that returns list of items of the type (table) when called -# this allows generator to implement retry logic -# TExtractorItem = Callable[[], Iterator[StrAny]] -# # extractor generator yields tuples: (type of the item (table name), function defined above) -# TExtractorItemWithTable = Tuple[str, TExtractorItem] -# TExtractorGenerator = Callable[[DictStrAny], 
Iterator[TExtractorItemWithTable]] - - -@dataclass -class PipelineCredentials: - CLIENT_TYPE: TLoaderType - - @property - def default_dataset(self) -> str: - pass - - @default_dataset.setter - def default_dataset(self, new_value: str) -> None: - pass - -@dataclass -class GCPPipelineCredentials(PipelineCredentials): - PROJECT_ID: str = None - DEFAULT_DATASET: str = None - CLIENT_EMAIL: str = None - PRIVATE_KEY: TSecretValue = None - LOCATION: str = "US" - CRED_TYPE: str = "service_account" - TOKEN_URI: str = "https://oauth2.googleapis.com/token" - HTTP_TIMEOUT: float = 15.0 - RETRY_DEADLINE: float = 600 - - @property - def default_dataset(self) -> str: - return self.DEFAULT_DATASET - - @default_dataset.setter - def default_dataset(self, new_value: str) -> None: - self.DEFAULT_DATASET = new_value - - @classmethod - def from_services_dict(cls, services: StrAny, dataset_prefix: str, location: str = "US") -> "GCPPipelineCredentials": - assert dataset_prefix is not None - return cls("bigquery", services["project_id"], dataset_prefix, services["client_email"], services["private_key"], location or cls.LOCATION) - - @classmethod - def from_services_file(cls, services_path: str, dataset_prefix: str, location: str = "US") -> "GCPPipelineCredentials": - with open(services_path, "r", encoding="utf-8") as f: - services = json.load(f) - return GCPPipelineCredentials.from_services_dict(services, dataset_prefix, location) - - @classmethod - def default_credentials(cls, dataset_prefix: str, project_id: str = None, location: str = None) -> "GCPPipelineCredentials": - return cls("bigquery", project_id, dataset_prefix, None, None, location or cls.LOCATION) - - -@dataclass -class PostgresPipelineCredentials(PipelineCredentials): - DBNAME: str = None - DEFAULT_DATASET: str = None - USER: str = None - HOST: str = None - PASSWORD: TSecretValue = None - PORT: int = 5439 - CONNECT_TIMEOUT: int = 15 - - @property - def default_dataset(self) -> str: - return self.DEFAULT_DATASET - - @default_dataset.setter - def default_dataset(self, new_value: str) -> None: - self.DEFAULT_DATASET = new_value - - -def credentials_from_dict(credentials: StrAny) -> PipelineCredentials: - - def ignore_unknown_props(typ_: Type[Any], props: StrAny) -> StrAny: - fields = {f.name: f for f in dtc_fields(typ_)} - return {k:v for k,v in props.items() if k in fields} - - client_type = credentials.get("CLIENT_TYPE") - if client_type == "bigquery": - return GCPPipelineCredentials(**ignore_unknown_props(GCPPipelineCredentials, credentials)) - elif client_type == "redshift": - return PostgresPipelineCredentials(**ignore_unknown_props(PostgresPipelineCredentials, credentials)) - else: - raise ValueError(f"CLIENT_TYPE: {client_type}") From 0290aa34cf2e9706e6b47ff2fd37d1ffbafaf108 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 28 Oct 2022 11:33:55 +0200 Subject: [PATCH 54/66] moves pipeline v2 in --- {experiments => dlt}/pipeline/README.md | 0 {experiments => dlt}/pipeline/__init__.py | 0 {experiments => dlt}/pipeline/configuration.py | 0 {experiments => dlt}/pipeline/decorators.py | 0 {experiments => dlt}/pipeline/exceptions.py | 0 {experiments => dlt}/pipeline/pipeline.py | 0 {experiments => dlt}/pipeline/typing.py | 0 7 files changed, 0 insertions(+), 0 deletions(-) rename {experiments => dlt}/pipeline/README.md (100%) rename {experiments => dlt}/pipeline/__init__.py (100%) rename {experiments => dlt}/pipeline/configuration.py (100%) rename {experiments => dlt}/pipeline/decorators.py (100%) rename {experiments => dlt}/pipeline/exceptions.py 
(100%) rename {experiments => dlt}/pipeline/pipeline.py (100%) rename {experiments => dlt}/pipeline/typing.py (100%) diff --git a/experiments/pipeline/README.md b/dlt/pipeline/README.md similarity index 100% rename from experiments/pipeline/README.md rename to dlt/pipeline/README.md diff --git a/experiments/pipeline/__init__.py b/dlt/pipeline/__init__.py similarity index 100% rename from experiments/pipeline/__init__.py rename to dlt/pipeline/__init__.py diff --git a/experiments/pipeline/configuration.py b/dlt/pipeline/configuration.py similarity index 100% rename from experiments/pipeline/configuration.py rename to dlt/pipeline/configuration.py diff --git a/experiments/pipeline/decorators.py b/dlt/pipeline/decorators.py similarity index 100% rename from experiments/pipeline/decorators.py rename to dlt/pipeline/decorators.py diff --git a/experiments/pipeline/exceptions.py b/dlt/pipeline/exceptions.py similarity index 100% rename from experiments/pipeline/exceptions.py rename to dlt/pipeline/exceptions.py diff --git a/experiments/pipeline/pipeline.py b/dlt/pipeline/pipeline.py similarity index 100% rename from experiments/pipeline/pipeline.py rename to dlt/pipeline/pipeline.py diff --git a/experiments/pipeline/typing.py b/dlt/pipeline/typing.py similarity index 100% rename from experiments/pipeline/typing.py rename to dlt/pipeline/typing.py From d16b430fc191e8edde77d6f08279c95b7f60ed51 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 28 Oct 2022 11:45:12 +0200 Subject: [PATCH 55/66] fixes pipeline imports --- dlt/pipeline/__init__.py | 6 +++--- dlt/pipeline/configuration.py | 2 +- dlt/pipeline/exceptions.py | 2 +- dlt/pipeline/pipeline.py | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py index 53627f01ca..9408db4037 100644 --- a/dlt/pipeline/__init__.py +++ b/dlt/pipeline/__init__.py @@ -6,9 +6,9 @@ from dlt.common.destination import DestinationReference, resolve_destination_reference from dlt.common.pipeline import PipelineContext, get_default_working_dir -from experiments.pipeline.configuration import PipelineConfiguration -from experiments.pipeline.pipeline import Pipeline -from experiments.pipeline.decorators import source, resource +from dlt.pipeline.configuration import PipelineConfiguration +from dlt.pipeline.pipeline import Pipeline +from dlt.pipeline.decorators import source, resource # @overload diff --git a/dlt/pipeline/configuration.py b/dlt/pipeline/configuration.py index 2f374524ff..27f63f4f7d 100644 --- a/dlt/pipeline/configuration.py +++ b/dlt/pipeline/configuration.py @@ -6,7 +6,7 @@ from dlt.common.typing import TSecretValue from dlt.common.utils import uniq_id -from experiments.pipeline.typing import TPipelineState +from dlt.pipeline.typing import TPipelineState @configspec diff --git a/dlt/pipeline/exceptions.py b/dlt/pipeline/exceptions.py index 7cb58e59fa..c8598f8117 100644 --- a/dlt/pipeline/exceptions.py +++ b/dlt/pipeline/exceptions.py @@ -1,7 +1,7 @@ from typing import Any, Sequence from dlt.common.exceptions import DltException, ArgumentsOverloadException from dlt.common.telemetry import TRunMetrics -from experiments.pipeline.typing import TPipelineStep +from dlt.pipeline.typing import TPipelineStep class PipelineException(DltException): diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 64394cc12d..a03617c507 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -32,10 +32,10 @@ from dlt.load import Load from dlt.normalize.configuration import 
NormalizeConfiguration -from experiments.pipeline.exceptions import PipelineConfigMissing, MissingDependencyException, PipelineStepFailed, SqlClientNotAvailable +from dlt.pipeline.exceptions import PipelineConfigMissing, MissingDependencyException, PipelineStepFailed, SqlClientNotAvailable from dlt.extract.sources import DltResource, DltSource, TTableSchemaTemplate -from experiments.pipeline.typing import TPipelineStep, TPipelineState -from experiments.pipeline.configuration import StateInjectableContext +from dlt.pipeline.typing import TPipelineStep, TPipelineState +from dlt.pipeline.configuration import StateInjectableContext class Pipeline: From b2a9bde197ba740839bcaa36d357d061dd35a3a7 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 28 Oct 2022 17:10:53 +0200 Subject: [PATCH 56/66] fixes typing errors, adds overloads --- dlt/cli/dlt.py | 44 +++--- dlt/common/configuration/inject.py | 6 +- dlt/common/configuration/resolve.py | 5 +- .../specs/load_volume_configuration.py | 6 + dlt/common/destination.py | 16 +- dlt/common/normalizers/json/relational.py | 4 +- dlt/common/source.py | 42 +++--- dlt/common/typing.py | 8 +- dlt/dbt_runner/configuration.py | 4 +- dlt/extract/sources.py | 19 +-- dlt/load/configuration.py | 14 +- dlt/normalize/configuration.py | 15 ++ dlt/pipeline/__init__.py | 6 +- dlt/pipeline/decorators.py | 139 ++++++++++++++++-- dlt/pipeline/pipeline.py | 120 ++++++++------- 15 files changed, 300 insertions(+), 148 deletions(-) diff --git a/dlt/cli/dlt.py b/dlt/cli/dlt.py index 43571860c2..774e78d1ba 100644 --- a/dlt/cli/dlt.py +++ b/dlt/cli/dlt.py @@ -8,7 +8,7 @@ from dlt.common.schema import Schema from dlt.common.typing import DictStrAny -from dlt.pipeline import Pipeline, PostgresPipelineCredentials +from dlt.pipeline import pipeline def add_pool_cli_arguments(parser: argparse.ArgumentParser) -> None: @@ -26,33 +26,35 @@ def add_pool_cli_arguments(parser: argparse.ArgumentParser) -> None: def main() -> None: parser = argparse.ArgumentParser(description="Runs various DLT modules", formatter_class=argparse.ArgumentDefaultsHelpFormatter) subparsers = parser.add_subparsers(dest="command") - normalize = subparsers.add_parser("normalize", help="Runs normalize") - add_pool_cli_arguments(normalize) - load = subparsers.add_parser("load", help="Runs loader") - add_pool_cli_arguments(load) + + # normalize = subparsers.add_parser("normalize", help="Runs normalize") + # add_pool_cli_arguments(normalize) + # load = subparsers.add_parser("load", help="Runs loader") + # add_pool_cli_arguments(load) + dbt = subparsers.add_parser("dbt", help="Executes dbt package") add_pool_cli_arguments(dbt) schema = subparsers.add_parser("schema", help="Shows, converts and upgrades schemas") schema.add_argument("file", help="Schema file name, in yaml or json format, will autodetect based on extension") schema.add_argument("--format", choices=["json", "yaml"], default="yaml", help="Display schema in this format") schema.add_argument("--remove-defaults", action="store_true", help="Does not show default hint values") - pipeline = subparsers.add_parser("pipeline", help="Operations on the pipelines") - pipeline.add_argument("name", help="Pipeline name") - pipeline.add_argument("workdir", help="Pipeline working directory") - pipeline.add_argument("operation", choices=["failed_loads"], default="failed_loads", help="Show failed loads for a pipeline") + pipe_cmd = subparsers.add_parser("pipeline", help="Operations on the pipelines") + pipe_cmd.add_argument("name", help="Pipeline name") + 
pipe_cmd.add_argument("operation", choices=["failed_loads"], default="failed_loads", help="Show failed loads for a pipeline") + pipe_cmd.add_argument("--workdir", help="Pipeline working directory", default=None) # TODO: consider using fire: https://github.com/google/python-fire # TODO: this also looks better https://click.palletsprojects.com/en/8.1.x/complex/#complex-guide args = parser.parse_args() run_f: Callable[[TRunnerArgs], None] = None - if args.command == "normalize": - from dlt.normalize.normalize import run_main as normalize_run - run_f = normalize_run - elif args.command == "load": - from dlt.load.load import run_main as loader_run - run_f = loader_run - elif args.command == "dbt": + # if args.command == "normalize": + # from dlt.normalize.normalize import run_main as normalize_run + # run_f = normalize_run + # elif args.command == "load": + # from dlt.load.load import run_main as loader_run + # run_f = loader_run + if args.command == "dbt": from dlt.dbt_runner.runner import run_main as dbt_run run_f = dbt_run elif args.command == "schema": @@ -69,12 +71,14 @@ def main() -> None: print(schema_str) exit(0) elif args.command == "pipeline": - p = Pipeline(args.name) - p.restore_pipeline(PostgresPipelineCredentials("dummy"), args.workdir) - completed_loads = p.list_completed_loads() + # from dlt.load import dummy + + p = pipeline(pipeline_name=args.name, working_dir=args.workdir, destination="dummy") + print(f"Checking pipeline {p.pipeline_name} ({args.name}) in {p.working_dir} ({args.workdir}) with state {p._state}") + completed_loads = p.list_completed_load_packages() for load_id in completed_loads: print(f"Checking failed jobs in {load_id}") - for job, failed_message in p.list_failed_jobs(load_id): + for job, failed_message in p.list_failed_jobs_in_package(load_id): print(f"JOB: {job}\nMSG: {failed_message}") exit(0) else: diff --git a/dlt/common/configuration/inject.py b/dlt/common/configuration/inject.py index 52eef51b3d..092f2b234c 100644 --- a/dlt/common/configuration/inject.py +++ b/dlt/common/configuration/inject.py @@ -15,7 +15,7 @@ _LAST_DLT_CONFIG = "_last_dlt_config" TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration) # keep a registry of all the decorated functions -_FUNC_SPECS: Dict[str, Type[BaseConfiguration]] = {} +_FUNC_SPECS: Dict[int, Type[BaseConfiguration]] = {} def get_fun_spec(f: AnyFun) -> Type[BaseConfiguration]: @@ -120,8 +120,8 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: return decorator(func) -def last_config(**kwargs: Any) -> TConfiguration: - return kwargs[_LAST_DLT_CONFIG] +def last_config(**kwargs: Any) -> BaseConfiguration: + return kwargs[_LAST_DLT_CONFIG] # type: ignore def _get_spec_name_from_f(f: AnyFun) -> str: diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index ea773c2d08..5be97e4e92 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -1,7 +1,7 @@ import ast import inspect from collections.abc import Mapping as C_Mapping -from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple, Type, TypeVar, get_origin +from typing import Any, Dict, ContextManager, List, Optional, Sequence, Tuple, Type, TypeVar, get_origin from dlt.common import json, logger from dlt.common.typing import TSecretValue, is_optional_type, extract_inner_type @@ -64,7 +64,7 @@ def serialize_value(value: Any) -> Any: return coerce_type("text", value_dt, value) -def inject_namespace(namespace_context: ConfigNamespacesContext, merge_existing: bool = True) 
-> Generator[ConfigNamespacesContext, None, None]: +def inject_namespace(namespace_context: ConfigNamespacesContext, merge_existing: bool = True) -> ContextManager[ConfigNamespacesContext]: """Adds `namespace` context to container, making it injectable. Optionally merges the context already in the container with the one provided Args: @@ -91,6 +91,7 @@ def _resolve_configuration( initial_value: Any, accept_partial: bool ) -> TConfiguration: + # print(f"RESOLVING: {locals()}") # do not resolve twice if config.is_resolved(): return config diff --git a/dlt/common/configuration/specs/load_volume_configuration.py b/dlt/common/configuration/specs/load_volume_configuration.py index 3846b78bd9..c014a66d43 100644 --- a/dlt/common/configuration/specs/load_volume_configuration.py +++ b/dlt/common/configuration/specs/load_volume_configuration.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec @@ -5,3 +7,7 @@ class LoadVolumeConfiguration(BaseConfiguration): load_volume_path: str = None # path to volume where files to be loaded to analytical storage are stored delete_completed_jobs: bool = False # if set to true the folder with completed jobs will be deleted + + if TYPE_CHECKING: + def __init__(self, load_volume_path: str = None, delete_completed_jobs: bool = None) -> None: + ... diff --git a/dlt/common/destination.py b/dlt/common/destination.py index c3cf38b3f2..e3cda90278 100644 --- a/dlt/common/destination.py +++ b/dlt/common/destination.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from importlib import import_module from types import TracebackType -from typing import ClassVar, List, Optional, Literal, Type, Protocol, Union +from typing import ClassVar, List, Optional, Literal, Type, Protocol, Union, TYPE_CHECKING from dlt.common.schema import Schema from dlt.common.schema.typing import TTableSchema @@ -38,12 +38,26 @@ class DestinationClientConfiguration(BaseConfiguration): destination_name: str = None # which destination to load data to credentials: Optional[CredentialsConfiguration] + if TYPE_CHECKING: + def __init__(self, destination_name: str = None, credentials: Optional[CredentialsConfiguration] = None) -> None: + ... + @configspec(init=True) class DestinationClientDwhConfiguration(DestinationClientConfiguration): dataset_name: str = None # dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix default_schema_name: Optional[str] = None # name of default schema to be used to name effective dataset to load data to + if TYPE_CHECKING: + def __init__( + self, + destination_name: str = None, + credentials: Optional[CredentialsConfiguration] = None, + dataset_name: str = None, + default_schema_name: Optional[str] = None + ) -> None: + ... 
+ TLoadJobStatus = Literal["running", "failed", "retry", "completed"] diff --git a/dlt/common/normalizers/json/relational.py b/dlt/common/normalizers/json/relational.py index ad58de1659..1d87e48a11 100644 --- a/dlt/common/normalizers/json/relational.py +++ b/dlt/common/normalizers/json/relational.py @@ -6,7 +6,7 @@ from dlt.common.schema.utils import column_name_validator from dlt.common.utils import uniq_id, digest128 from dlt.common.normalizers.json import TNormalizedRowIterator, wrap_in_dict -from dlt.common.source import TEventDLTMeta +# from dlt.common.source import TEventDLTMeta from dlt.common.validation import validate_dict @@ -16,7 +16,7 @@ class TDataItemRow(TypedDict, total=False): class TDataItemRowRoot(TDataItemRow, total=False): _dlt_load_id: str # load id to identify records loaded together that ie. need to be processed - _dlt_meta: TEventDLTMeta # stores metadata, should never be sent to the normalizer + # _dlt_meta: TEventDLTMeta # stores metadata, should never be sent to the normalizer class TDataItemRowChild(TDataItemRow, total=False): diff --git a/dlt/common/source.py b/dlt/common/source.py index 2bda51a5bb..0d567fdf0f 100644 --- a/dlt/common/source.py +++ b/dlt/common/source.py @@ -1,14 +1,10 @@ from collections import abc from functools import wraps from typing import Any, Callable, Optional, Sequence, TypeVar, Union, TypedDict, List, Awaitable -try: - from typing_extensions import ParamSpec -except ImportError: - ParamSpec = lambda x: [x] # type: ignore from dlt.common import logger from dlt.common.time import sleep -from dlt.common.typing import StrAny, TDataItem +from dlt.common.typing import ParamSpec, TDataItem # possible types of items yielded by the source @@ -31,33 +27,33 @@ TTableHintTemplate = Union[TDynHintType, TFunHintTemplate[TDynHintType]] # name of dlt metadata as part of the item -DLT_METADATA_FIELD = "_dlt_meta" +# DLT_METADATA_FIELD = "_dlt_meta" -class TEventDLTMeta(TypedDict, total=False): - table_name: str # a root table in which store the event +# class TEventDLTMeta(TypedDict, total=False): +# table_name: str # a root table in which store the event -def append_dlt_meta(item: TBoundItem, name: str, value: Any) -> TBoundItem: - if isinstance(item, abc.Sequence): - for i in item: - i.setdefault(DLT_METADATA_FIELD, {})[name] = value - elif isinstance(item, dict): - item.setdefault(DLT_METADATA_FIELD, {})[name] = value +# def append_dlt_meta(item: TBoundItem, name: str, value: Any) -> TBoundItem: +# if isinstance(item, abc.Sequence): +# for i in item: +# i.setdefault(DLT_METADATA_FIELD, {})[name] = value +# elif isinstance(item, dict): +# item.setdefault(DLT_METADATA_FIELD, {})[name] = value - return item +# return item -def with_table_name(item: TBoundItem, table_name: str) -> TBoundItem: - # normalize table name before adding - return append_dlt_meta(item, "table_name", table_name) +# def with_table_name(item: TBoundItem, table_name: str) -> TBoundItem: +# # normalize table name before adding +# return append_dlt_meta(item, "table_name", table_name) -def get_table_name(item: StrAny) -> Optional[str]: - if DLT_METADATA_FIELD in item: - meta: TEventDLTMeta = item[DLT_METADATA_FIELD] - return meta.get("table_name", None) - return None +# def get_table_name(item: StrAny) -> Optional[str]: +# if DLT_METADATA_FIELD in item: +# meta: TEventDLTMeta = item[DLT_METADATA_FIELD] +# return meta.get("table_name", None) +# return None def with_retry(max_retries: int = 3, retry_sleep: float = 1.0) -> Callable[[Callable[_TFunParams, TBoundItem]], 
Callable[_TFunParams, TBoundItem]]: diff --git a/dlt/common/typing.py b/dlt/common/typing.py index 889220f72b..089aa8189c 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -1,18 +1,18 @@ from collections.abc import Mapping as C_Mapping, Sequence as C_Sequence from re import Pattern as _REPattern from typing import Callable, Dict, Any, Literal, Mapping, NewType, Tuple, Type, TypeVar, Generic, Protocol, Iterable, TYPE_CHECKING, Union, runtime_checkable, get_args, get_origin -try: - from typing_extensions import ParamSpec, TypeAlias, TypeGuard -except ImportError: - ParamSpec = lambda x: [x] # type: ignore +from typing_extensions import ParamSpec, TypeAlias, TypeGuard + if TYPE_CHECKING: from _typeshed import StrOrBytesPath + from typing_extensions import ParamSpec from typing import _TypedDict REPattern = _REPattern[str] else: StrOrBytesPath = Any from typing import _TypedDictMeta as _TypedDict REPattern = _REPattern + ParamSpec = lambda x: [x] DictStrAny: TypeAlias = Dict[str, Any] DictStrStr: TypeAlias = Dict[str, str] diff --git a/dlt/dbt_runner/configuration.py b/dlt/dbt_runner/configuration.py index 2cb17831c1..b0217e9267 100644 --- a/dlt/dbt_runner/configuration.py +++ b/dlt/dbt_runner/configuration.py @@ -5,11 +5,11 @@ from dlt.common.typing import StrAny, TSecretValue from dlt.common.configuration import resolve_configuration, configspec from dlt.common.configuration.providers import EnvironProvider -from dlt.common.configuration.specs import PoolRunnerConfiguration, TPoolType, PostgresCredentials, GcpClientCredentials +from dlt.common.configuration.specs import RunConfiguration, PoolRunnerConfiguration, TPoolType, PostgresCredentials, GcpClientCredentials @configspec -class DBTRunnerConfiguration(PoolRunnerConfiguration): +class DBTRunnerConfiguration(RunConfiguration, PoolRunnerConfiguration): pool_type: TPoolType = "none" stop_after_runs: int = 1 package_volume_path: str = "/var/local/app" diff --git a/dlt/extract/sources.py b/dlt/extract/sources.py index ac25825215..fb29e1ace4 100644 --- a/dlt/extract/sources.py +++ b/dlt/extract/sources.py @@ -1,6 +1,7 @@ import contextlib from copy import deepcopy import inspect +from collections.abc import Mapping as C_Mapping from typing import AsyncIterable, AsyncIterator, Coroutine, Dict, Generator, Iterable, Iterator, List, NamedTuple, Set, TypedDict, Union, Awaitable, Callable, Sequence, TypeVar, cast, Optional, Any from dlt.common.exceptions import DltException @@ -8,25 +9,13 @@ from dlt.common.source import TFunHintTemplate, TDirectDataItem, TTableHintTemplate from dlt.common.schema import Schema from dlt.common.schema.utils import new_table -from dlt.common.schema.typing import TPartialTableSchema, TTableSchema, TTableSchemaColumns, TWriteDisposition +from dlt.common.schema.typing import TColumnSchema, TPartialTableSchema, TTableSchema, TTableSchemaColumns, TWriteDisposition from dlt.common.configuration.container import Container from dlt.common.pipeline import PipelineContext from dlt.extract.pipe import FilterItem, Pipe, CreatePipeException, PipeIterator -# class HintArgs(NamedTuple): -# table_name: TTableHintTemplate[str] -# parent_table_name: TTableHintTemplate[str] = None -# write_disposition: TTableHintTemplate[TWriteDisposition] = None -# columns: TTableHintTemplate[TTableSchemaColumns] = None - - -# def apply_args(args: HintArgs): -# pass - -# apply_args() - class TTableSchemaTemplate(TypedDict, total=False): name: TTableHintTemplate[str] description: TTableHintTemplate[str] @@ -115,6 +104,10 @@ def 
new_table_template( if not table_name: raise InvalidTableSchemaTemplate("Table template name must be a string or function taking TDataItem") # create a table schema template where hints can be functions taking TDataItem + if isinstance(columns, C_Mapping): + # new_table accepts a sequence + columns = columns.values() # type: ignore + new_template: TTableSchemaTemplate = new_table(table_name, parent_table_name, write_disposition=write_disposition, columns=columns) # type: ignore # if any of the hints is a function then name must be as well if any(callable(v) for k, v in new_template.items() if k != "name") and not callable(table_name): diff --git a/dlt/load/configuration.py b/dlt/load/configuration.py index 22a985a00d..bc13eb15ab 100644 --- a/dlt/load/configuration.py +++ b/dlt/load/configuration.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import TYPE_CHECKING from dlt.common.configuration import configspec from dlt.common.configuration.specs import BaseConfiguration, PoolRunnerConfiguration, CredentialsConfiguration, TPoolType @@ -11,3 +11,15 @@ class LoaderConfiguration(PoolRunnerConfiguration): pool_type: TPoolType = "thread" # mostly i/o (upload) so may be thread pool always_wipe_storage: bool = False # removes all data in the storage load_storage_config: LoadVolumeConfiguration = None + + if TYPE_CHECKING: + def __init__( + self, + pool_type: TPoolType = None, + workers: int = None, + exit_on_exception: bool = None, + is_single_run: bool = None, + always_wipe_storage: bool = None, + load_storage_config: LoadVolumeConfiguration = None + ) -> None: + ... diff --git a/dlt/normalize/configuration.py b/dlt/normalize/configuration.py index df4540aa0b..16f85dc880 100644 --- a/dlt/normalize/configuration.py +++ b/dlt/normalize/configuration.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + from dlt.common.configuration import configspec from dlt.common.configuration.specs import LoadVolumeConfiguration, NormalizeVolumeConfiguration, SchemaVolumeConfiguration, PoolRunnerConfiguration, TPoolType from dlt.common.destination import DestinationCapabilitiesContext @@ -10,3 +12,16 @@ class NormalizeConfiguration(PoolRunnerConfiguration): schema_storage_config: SchemaVolumeConfiguration normalize_storage_config: NormalizeVolumeConfiguration load_storage_config: LoadVolumeConfiguration + + if TYPE_CHECKING: + def __init__( + self, + pool_type: TPoolType = None, + workers: int = None, + exit_on_exception: bool = None, + is_single_run: bool = None, + schema_storage_config: SchemaVolumeConfiguration = None, + normalize_storage_config: NormalizeVolumeConfiguration = None, + load_storage_config: LoadVolumeConfiguration = None + ) -> None: + ... 
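Note on the TYPE_CHECKING-only __init__ stubs added to LoaderConfiguration and NormalizeConfiguration above: they exist purely for static analysis. The real constructor is presumably synthesized by the configspec/dataclass machinery at runtime, while the stub documents the accepted keyword arguments for IDEs and mypy. A minimal sketch of the pattern, with illustrative class and field names that are not part of the patch:

from typing import TYPE_CHECKING


class ExampleRunnerConfiguration:
    # plain class-level defaults; the real project derives these via @configspec
    pool_type: str = "thread"
    workers: int = 4

    if TYPE_CHECKING:
        # seen only by type checkers and IDEs: documents the keyword arguments
        # without creating a real __init__ at runtime (TYPE_CHECKING is False then)
        def __init__(self, pool_type: str = None, workers: int = None) -> None:
            ...
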
diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py index 9408db4037..24d312271d 100644 --- a/dlt/pipeline/__init__.py +++ b/dlt/pipeline/__init__.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, cast from dlt.common.typing import TSecretValue, Any from dlt.common.configuration import with_config @@ -42,9 +42,10 @@ def pipeline( context = Container()[PipelineContext] # if pipeline instance is already active then return it, otherwise create a new one if context.is_activated(): - return context.pipeline() + return cast(Pipeline, context.pipeline()) print(kwargs["_last_dlt_config"].pipeline_name) + print(kwargs["_last_dlt_config"].runtime.log_level) # if working_dir not provided use temp folder if not working_dir: working_dir = get_default_working_dir() @@ -57,7 +58,6 @@ def pipeline( return p # setup default pipeline in the container -print("CONTEXT") Container()[PipelineContext] = PipelineContext(pipeline) diff --git a/dlt/pipeline/decorators.py b/dlt/pipeline/decorators.py index 38305d6582..fd0e3214a9 100644 --- a/dlt/pipeline/decorators.py +++ b/dlt/pipeline/decorators.py @@ -1,7 +1,7 @@ import inspect from types import ModuleType from makefun import wraps -from typing import Any, Dict, NamedTuple, Optional, Type +from typing import Any, Callable, Dict, Iterator, List, NamedTuple, Optional, Sequence, Tuple, Type, Union, overload from dlt.common.configuration import with_config, get_fun_spec from dlt.common.configuration.specs import BaseConfiguration @@ -9,7 +9,7 @@ from dlt.common.schema.schema import Schema from dlt.common.schema.typing import TTableSchemaColumns, TWriteDisposition from dlt.common.source import TTableHintTemplate, TFunHintTemplate -from dlt.common.typing import AnyFun, TFun +from dlt.common.typing import AnyFun, TFun, ParamSpec from dlt.common.utils import is_inner_function from dlt.extract.sources import DltResource, DltSource @@ -22,13 +22,24 @@ class SourceInfo(NamedTuple): _SOURCES: Dict[str, SourceInfo] = {} +TSourceFunParams = ParamSpec("TSourceFunParams") +TResourceFunParams = ParamSpec("TResourceFunParams") -def source(func: Optional[AnyFun] = None, /, name: str = None, schema: Schema = None, spec: Type[BaseConfiguration] = None): + +@overload +def source(func: Callable[TSourceFunParams, Any], /, name: str = None, schema: Schema = None, spec: Type[BaseConfiguration] = None) -> Callable[TSourceFunParams, DltSource]: + ... + +@overload +def source(func: None = ..., /, name: str = None, schema: Schema = None, spec: Type[BaseConfiguration] = None) -> Callable[[Callable[TSourceFunParams, Any]], Callable[TSourceFunParams, DltSource]]: + ... + +def source(func: Optional[AnyFun] = None, /, name: str = None, schema: Schema = None, spec: Type[BaseConfiguration] = None) -> Any: if name and schema: raise ArgumentsOverloadException("Source name cannot be set if schema is present") - def decorator(f: TFun) -> TFun: + def decorator(f: Callable[TSourceFunParams, Any]) -> Callable[TSourceFunParams, DltSource]: nonlocal schema, name # extract name @@ -69,7 +80,8 @@ def check_rv_type(rv: Any) -> None: # store the source information _SOURCES[_wrap.__qualname__] = SourceInfo(SPEC, _wrap, inspect.getmodule(f)) - return _wrap + # the typing is right, but makefun.wraps does not preserve signatures + return _wrap # type: ignore if func is None: # we're called with parens. 
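The paired @overload declarations above are what allow the same decorator to be used both bare (@source) and with arguments (@source(name=...)), while keeping the wrapped function's parameters visible to type checkers. A generic, dlt-independent sketch of that calling convention, simplified to a plain TypeVar instead of ParamSpec (tag and __tag_name__ are invented names used only for illustration):

from typing import Any, Callable, Optional, TypeVar, overload

TFunc = TypeVar("TFunc", bound=Callable[..., Any])


@overload
def tag(func: TFunc, /) -> TFunc: ...

@overload
def tag(func: None = ..., /, name: str = None) -> Callable[[TFunc], TFunc]: ...

def tag(func: Optional[Callable[..., Any]] = None, /, name: str = None) -> Any:
    def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
        # record a name on the function; stand-in for registering a SourceInfo
        f.__tag_name__ = name or f.__name__  # type: ignore[attr-defined]
        return f

    if func is None:
        # called with parentheses: @tag(name="...") returns the real decorator
        return decorator
    # called bare: @tag decorates immediately
    return decorator(func)
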
@@ -82,23 +94,98 @@ def check_rv_type(rv: Any) -> None: return decorator(func) +# @source +# def reveal_1() -> None: +# pass + +# @source(name="revel") +# def reveal_2() -> None: +# pass + + +# def revel_3(v) -> int: +# pass + + +# reveal_type(reveal_1) +# reveal_type(reveal_1()) + +# reveal_type(reveal_2) +# reveal_type(reveal_2()) + +# reveal_type(source(revel_3)) +# reveal_type(source(revel_3)("s")) + +@overload +def resource( + data: Callable[TResourceFunParams, Any], + /, + name: str = None, + table_name_fun: TFunHintTemplate[str] = None, + write_disposition: TTableHintTemplate[TWriteDisposition] = None, + columns: TTableHintTemplate[TTableSchemaColumns] = None, + selected: bool = True, + depends_on: DltResource = None, + spec: Type[BaseConfiguration] = None +) -> Callable[TResourceFunParams, DltResource]: + ... + +@overload +def resource( + data: None = ..., + /, + name: str = None, + table_name_fun: TFunHintTemplate[str] = None, + write_disposition: TTableHintTemplate[TWriteDisposition] = None, + columns: TTableHintTemplate[TTableSchemaColumns] = None, + selected: bool = True, + depends_on: DltResource = None, + spec: Type[BaseConfiguration] = None +) -> Callable[[Callable[TResourceFunParams, Any]], Callable[TResourceFunParams, DltResource]]: + ... + + +# @overload +# def resource( +# data: Union[DltSource, DltResource, Sequence[DltSource], Sequence[DltResource]], +# / +# ) -> DltResource: +# ... + + +@overload +def resource( + data: Union[List[Any], Tuple[Any], Iterator[Any]], + /, + name: str = None, + table_name_fun: TFunHintTemplate[str] = None, + write_disposition: TTableHintTemplate[TWriteDisposition] = None, + columns: TTableHintTemplate[TTableSchemaColumns] = None, + selected: bool = True, + depends_on: DltResource = None, + spec: Type[BaseConfiguration] = None +) -> DltResource: + ... 
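For reference, a usage sketch of the three call forms these overloads describe: bare decorator, parametrized decorator, and a direct call on existing data. Table names and hint values are illustrative, and the import path reflects this point in the series (before the move to dlt.extract):

from dlt.pipeline.decorators import resource


@resource
def users():
    # function form: the decorated generator returns a DltResource when called
    yield {"id": 1, "name": "alice"}


@resource(name="orders", write_disposition="append")
def orders():
    yield {"order_id": 1, "amount": 100}


# data form: per the last overload, passing a list returns a DltResource directly
static_rows = resource([{"id": 1}, {"id": 2}], name="static_rows")
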
+ + def resource( data: Optional[Any] = None, /, - name: TTableHintTemplate[str] = None, + name: str = None, table_name_fun: TFunHintTemplate[str] = None, write_disposition: TTableHintTemplate[TWriteDisposition] = None, columns: TTableHintTemplate[TTableSchemaColumns] = None, selected: bool = True, depends_on: DltResource = None, - spec: Type[BaseConfiguration] = None): + spec: Type[BaseConfiguration] = None +) -> Any: - def make_resource(name, _data: Any) -> DltResource: - table_template = DltResource.new_table_template(table_name_fun or name, write_disposition=write_disposition, columns=columns) - return DltResource.from_data(_data, name, table_template, selected, depends_on) + def make_resource(_name: str, _data: Any) -> DltResource: + table_template = DltResource.new_table_template(table_name_fun or _name, write_disposition=write_disposition, columns=columns) + return DltResource.from_data(_data, _name, table_template, selected, depends_on) - def decorator(f: TFun) -> TFun: + def decorator(f: Callable[TResourceFunParams, Any]) -> Callable[TResourceFunParams, DltResource]: resource_name = name or f.__name__ # if f is not a generator (does not yield) raise Exception @@ -124,7 +211,8 @@ def _wrap(*args: Any, **kwargs: Any) -> DltResource: if SPEC: _SOURCES[_wrap.__qualname__] = SourceInfo(SPEC, _wrap, inspect.getmodule(f)) - return _wrap + # the typing is right, but makefun.wraps does not preserve signatures + return _wrap # type: ignore # if data is callable or none use decorator @@ -143,3 +231,30 @@ def _get_source_for_inner_function(f: AnyFun) -> Optional[SourceInfo]: parts = f.__qualname__.split(".") parent_fun = ".".join(parts[:-2]) return _SOURCES.get(parent_fun) + + +# @resource +# def reveal_1() -> None: +# pass + +# @resource(name="revel") +# def reveal_2() -> None: +# pass + + +# def revel_3(v) -> int: +# pass + + +# reveal_type(reveal_1) +# reveal_type(reveal_1()) + +# reveal_type(reveal_2) +# reveal_type(reveal_2()) + +# reveal_type(resource(revel_3)) +# reveal_type(resource(revel_3)("s")) + + +# reveal_type(resource([], name="aaaa")) +# reveal_type(resource("aaaaa", name="aaaa")) \ No newline at end of file diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index a03617c507..0a2433a2d7 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -3,16 +3,15 @@ from copy import deepcopy from functools import wraps from collections.abc import Sequence as C_Sequence -from typing import Any, Callable, ClassVar, List, Iterable, Iterator, Generator, Mapping, NewType, Optional, Sequence, Tuple, Type, TypedDict, Union, get_type_hints, overload +from typing import Any, Callable, ClassVar, List, Iterable, Iterator, Generator, Mapping, NewType, Optional, Protocol, Sequence, Tuple, Type, TypeVar, TypedDict, Union, get_type_hints, overload from dlt.common import json, logger, signals from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_namespace_context import ConfigNamespacesContext from dlt.common.runners.runnable import Runnable from dlt.common.schema.typing import TColumnSchema, TWriteDisposition -from dlt.common.source import DLT_METADATA_FIELD, TResolvableDataItem, with_table_name from dlt.common.storages.load_storage import LoadStorage -from dlt.common.typing import DictStrAny, StrAny, TFun, TSecretValue, TAny +from dlt.common.typing import ParamSpec, TFun, TSecretValue, TAny from dlt.common.runners import pool_runner as runner, TRunMetrics, initialize_runner from dlt.common.storages import LiveSchemaStorage, 
NormalizeStorage @@ -37,6 +36,49 @@ from dlt.pipeline.typing import TPipelineStep, TPipelineState from dlt.pipeline.configuration import StateInjectableContext +# TFunParams = ParamSpec("TFunParams") +# TSelfFun = TypeVar("TSelfFun", bound=Callable[["Pipeline", ], ]) +# class TSelfFun(Protocol): +# def __call__(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: ... + + +def with_state_sync(f: TFun) -> TFun: + + @wraps(f) + def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: + # backup and restore state + with self._managed_state() as state: + # add the state to container as a context + with self._container.injectable_context(StateInjectableContext(state=state)): + return f(self, *args, **kwargs) + + return _wrap # type: ignore + +def with_schemas_sync(f: TFun) -> TFun: + + @wraps(f) + def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: + for name in self._schema_storage.live_schemas: + # refresh live schemas in storage or import schema path + self._schema_storage.commit_live_schema(name) + return f(self, *args, **kwargs) + + return _wrap # type: ignore + +def with_config_namespace(namespaces: Tuple[str, ...]) -> Callable[[TFun], TFun]: + + def decorator(f: TFun) -> TFun: + + @wraps(f) + def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: + # add namespace context to the container to be used by all configuration without explicit namespaces resolution + with inject_namespace(ConfigNamespacesContext(pipeline_name=self.pipeline_name, namespaces=namespaces)): + return f(self, *args, **kwargs) + + return _wrap # type: ignore + + return decorator + class Pipeline: @@ -70,7 +112,6 @@ def __init__( self._state: TPipelineState = {} # type: ignore self._pipeline_storage: FileStorage = None self._schema_storage: LiveSchemaStorage = None - # self._pool_config: PoolRunnerConfiguration = None self._schema_storage_config: SchemaVolumeConfiguration = None self._normalize_storage_config: NormalizeVolumeConfiguration = None self._load_storage_config: LoadVolumeConfiguration = None @@ -78,49 +119,11 @@ def __init__( initialize_runner(self.runtime_config) self._configure(pipeline_name, working_dir, import_schema_path, export_schema_path, always_drop_pipeline) - def with_state_sync(f: TFun) -> TFun: - - @wraps(f) - def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: - # backup and restore state - with self._managed_state() as state: - # add the state to container as a context - with self._container.injectable_context(StateInjectableContext(state=state)): - return f(self, *args, **kwargs) - - return _wrap - - def with_schemas_sync(f: TFun) -> TFun: - - @wraps(f) - def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: - for name in self._schema_storage.live_schemas: - # refresh live schemas in storage or import schema path - self._schema_storage.commit_live_schema(name) - return f(self, *args, **kwargs) - - return _wrap - - def with_config_namespace(namespaces: Tuple[str, ...]) -> Callable[[TFun], TFun]: - - def decorator(f: TFun) -> TFun: - - @wraps(f) - def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: - # add namespace context to the container to be used by all configuration without explicit namespaces resolution - with inject_namespace(ConfigNamespacesContext(pipeline_name=self.pipeline_name, namespaces=namespaces)): - return f(self, *args, **kwargs) - - return _wrap - - return decorator - - def drop(self) -> "Pipeline": """Deletes existing pipeline state, schemas and drops datasets at the destination if present""" # drop the data for all 
known schemas for schema in self._schema_storage: - with self._get_destination_client(schema) as client: + with self._get_destination_client(self._schema_storage.load_schema(schema)) as client: client.initialize_storage(wipe_data=True) # reset the pipeline working dir self._create_pipeline() @@ -182,7 +185,10 @@ def extract( # return table_name or parent_table_name or write_disposition or schema def apply_hint_args(resource: DltResource) -> None: - resource.apply_hints(table_name, parent_table_name, write_disposition, columns) + columns_dict = None + if columns: + columns_dict = {c["name"]:c for c in columns} + resource.apply_hints(table_name, parent_table_name, write_disposition, columns_dict) def choose_schema() -> Schema: if schema: @@ -241,7 +247,7 @@ def item_to_source(data_item: Any) -> DltSource: @with_schemas_sync @with_config_namespace(("normalize",)) - def normalize(self, workers: int = 1, dry_run: bool = False) -> None: + def normalize(self, workers: int = 1) -> None: if is_interactive() and workers > 1: raise NotImplementedError("Do not use normalize workers in interactive mode ie. in notebook") # check if any schema is present, if not then no data was extracted @@ -276,7 +282,7 @@ def load( credentials: Any = None, # raise_on_failed_jobs = False, # raise_on_incompatible_schema = False, - always_wipe_storage = False, + always_wipe_storage: bool = False, *, workers: int = 20 ) -> None: @@ -343,15 +349,15 @@ def list_extracted_resources(self) -> Sequence[str]: return self._get_normalize_storage().list_files_to_normalize_sorted() def list_normalized_load_packages(self) -> Sequence[str]: - return self._get_load_storage().load_storage.list_packages() + return self._get_load_storage().list_packages() def list_completed_load_packages(self) -> Sequence[str]: - return self._get_load_storage().load_storage.list_completed_packages() + return self._get_load_storage().list_completed_packages() def list_failed_jobs_in_package(self, load_id: str) -> Sequence[Tuple[str, str]]: storage = self._get_load_storage() failed_jobs: List[Tuple[str, str]] = [] - for file in storage.load_storage.list_completed_failed_jobs(load_id): + for file in storage.list_completed_failed_jobs(load_id): if not file.endswith(".exception"): try: failed_message = storage.storage.load(file + ".exception") @@ -423,7 +429,7 @@ def _restore_pipeline(self) -> None: def _restore_state(self) -> None: self._state.clear() # type: ignore restored_state: TPipelineState = json.loads(self._pipeline_storage.load(Pipeline.STATE_FILE)) - self._state.update(restored_state) + self._state.update(restored_state) # type: ignore def _extract_source(self, source: DltSource, max_parallel_items: int, workers: int) -> None: # discover the schema from source @@ -460,11 +466,6 @@ def _run_step_in_pool(self, step: TPipelineStep, runnable: Runnable[Any], config signals.raise_if_signalled() def _run_f_in_pool(self, run_f: Callable[..., Any], config: PoolRunnerConfiguration) -> int: - # internal runners should work in single mode - self._loader_instance.config.is_single_run = True - self._loader_instance.config.exit_on_exception = True - self._normalize_instance.config.is_single_run = True - self._normalize_instance.config.exit_on_exception = True def _run(_: Any) -> TRunMetrics: rv = run_f() @@ -539,20 +540,15 @@ def _managed_state(self) -> Iterator[TPipelineState]: except Exception: # restore old state self._state.clear() # type: ignore - self._state.update(backup_state) + self._state.update(backup_state) # type: ignore raise else: # update state 
props for prop in Pipeline.STATE_PROPS: - self._state[prop] = getattr(self, prop) + self._state[prop] = getattr(self, prop) # type: ignore # compare backup and new state, save only if different new_state = json.dumps(self._state) old_state = json.dumps(backup_state) # persist old state if new_state != old_state: self._pipeline_storage.save(Pipeline.STATE_FILE, new_state) - - @property - def has_pending_loads(self) -> bool: - # TODO: check if has pending normalizer and loader data - pass From 3ce749e5f3e3c945974d1158e84efe1f6db54554 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 28 Oct 2022 17:15:15 +0200 Subject: [PATCH 57/66] moves source decorators to extract --- dlt/{pipeline => extract}/decorators.py | 1 + 1 file changed, 1 insertion(+) rename dlt/{pipeline => extract}/decorators.py (99%) diff --git a/dlt/pipeline/decorators.py b/dlt/extract/decorators.py similarity index 99% rename from dlt/pipeline/decorators.py rename to dlt/extract/decorators.py index fd0e3214a9..e385fb4b3a 100644 --- a/dlt/pipeline/decorators.py +++ b/dlt/extract/decorators.py @@ -11,6 +11,7 @@ from dlt.common.source import TTableHintTemplate, TFunHintTemplate from dlt.common.typing import AnyFun, TFun, ParamSpec from dlt.common.utils import is_inner_function + from dlt.extract.sources import DltResource, DltSource From daba606660fd73451068d18dc8766ea1b05f1098 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sun, 30 Oct 2022 01:30:00 +0200 Subject: [PATCH 58/66] adds restore pipeline and better state management, with cli support --- dlt/cli/dlt.py | 24 ++++--- dlt/pipeline/__init__.py | 58 +++++++++++------ dlt/pipeline/exceptions.py | 21 ++---- dlt/pipeline/pipeline.py | 128 ++++++++++++++++++++++--------------- dlt/pipeline/typing.py | 4 +- 5 files changed, 137 insertions(+), 98 deletions(-) diff --git a/dlt/cli/dlt.py b/dlt/cli/dlt.py index 774e78d1ba..07dc6ca729 100644 --- a/dlt/cli/dlt.py +++ b/dlt/cli/dlt.py @@ -8,7 +8,7 @@ from dlt.common.schema import Schema from dlt.common.typing import DictStrAny -from dlt.pipeline import pipeline +from dlt.pipeline import pipeline, restore def add_pool_cli_arguments(parser: argparse.ArgumentParser) -> None: @@ -40,7 +40,7 @@ def main() -> None: schema.add_argument("--remove-defaults", action="store_true", help="Does not show default hint values") pipe_cmd = subparsers.add_parser("pipeline", help="Operations on the pipelines") pipe_cmd.add_argument("name", help="Pipeline name") - pipe_cmd.add_argument("operation", choices=["failed_loads"], default="failed_loads", help="Show failed loads for a pipeline") + pipe_cmd.add_argument("operation", choices=["failed_loads", "drop"], default="failed_loads", help="Show failed loads for a pipeline") pipe_cmd.add_argument("--workdir", help="Pipeline working directory", default=None) # TODO: consider using fire: https://github.com/google/python-fire @@ -73,13 +73,19 @@ def main() -> None: elif args.command == "pipeline": # from dlt.load import dummy - p = pipeline(pipeline_name=args.name, working_dir=args.workdir, destination="dummy") - print(f"Checking pipeline {p.pipeline_name} ({args.name}) in {p.working_dir} ({args.workdir}) with state {p._state}") - completed_loads = p.list_completed_load_packages() - for load_id in completed_loads: - print(f"Checking failed jobs in {load_id}") - for job, failed_message in p.list_failed_jobs_in_package(load_id): - print(f"JOB: {job}\nMSG: {failed_message}") + p = restore(pipeline_name=args.name, working_dir=args.workdir) + print(f"Found pipeline {p.pipeline_name} ({args.name}) in 
{p.working_dir} ({args.workdir}) with state {p._get_state()}") + + if args.operation == "failed_loads": + completed_loads = p.list_completed_load_packages() + for load_id in completed_loads: + print(f"Checking failed jobs in load id '{load_id}'") + for job, failed_message in p.list_failed_jobs_in_package(load_id): + print(f"JOB: {os.path.abspath(job)}\nMSG: {failed_message}") + + if args.operation == "drop": + p.drop() + exit(0) else: parser.print_help() diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py index 24d312271d..043593fb60 100644 --- a/dlt/pipeline/__init__.py +++ b/dlt/pipeline/__init__.py @@ -3,26 +3,12 @@ from dlt.common.typing import TSecretValue, Any from dlt.common.configuration import with_config from dlt.common.configuration.container import Container -from dlt.common.destination import DestinationReference, resolve_destination_reference +from dlt.common.destination import DestinationReference from dlt.common.pipeline import PipelineContext, get_default_working_dir from dlt.pipeline.configuration import PipelineConfiguration from dlt.pipeline.pipeline import Pipeline -from dlt.pipeline.decorators import source, resource - - -# @overload -# def configure(self, -# pipeline_name: str = None, -# working_dir: str = None, -# pipeline_secret: TSecretValue = None, -# drop_existing_data: bool = False, -# import_schema_path: str = None, -# export_schema_path: str = None, -# destination_name: str = None, -# log_level: str = "INFO" -# ) -> None: -# ... +from dlt.extract.decorators import source, resource @with_config(spec=PipelineConfiguration, auto_namespace=True) @@ -49,18 +35,52 @@ def pipeline( # if working_dir not provided use temp folder if not working_dir: working_dir = get_default_working_dir() - destination = resolve_destination_reference(destination) + destination = DestinationReference.from_name(destination) # create new pipeline instance - p = Pipeline(pipeline_name, working_dir, pipeline_secret, destination, dataset_name, import_schema_path, export_schema_path, always_drop_pipeline, kwargs["runtime"]) + p = Pipeline(pipeline_name, working_dir, pipeline_secret, destination, dataset_name, import_schema_path, export_schema_path, always_drop_pipeline, False, kwargs["runtime"]) # set it as current pipeline Container()[PipelineContext].activate(p) return p + +def restore( + pipeline_name: str = None, + working_dir: str = None, + pipeline_secret: TSecretValue = None +) -> Pipeline: + + _pipeline_name = pipeline_name + _working_dir = working_dir + + @with_config(spec=PipelineConfiguration, auto_namespace=True) + def _restore( + pipeline_name: str, + working_dir: str, + pipeline_secret: TSecretValue, + always_drop_pipeline: bool = False, + **kwargs: Any + ) -> Pipeline: + # use the outer pipeline name and working dir to override those from config in order to restore the requested state + pipeline_name = _pipeline_name or pipeline_name + working_dir = _working_dir or working_dir + + # if working_dir not provided use temp folder + if not working_dir: + working_dir = get_default_working_dir() + # create new pipeline instance + p = Pipeline(pipeline_name, working_dir, pipeline_secret, None, None, None, None, always_drop_pipeline, True, kwargs["runtime"]) + # set it as current pipeline + Container()[PipelineContext].activate(p) + return p + + return _restore(pipeline_name, working_dir, pipeline_secret) + + # setup default pipeline in the container Container()[PipelineContext] = PipelineContext(pipeline) def run(source: Any, destination: Union[None, str, 
DestinationReference] = None) -> Pipeline: - destination = resolve_destination_reference(destination) + destination = DestinationReference.from_name(destination) return pipeline().run(source=source, destination=destination) diff --git a/dlt/pipeline/exceptions.py b/dlt/pipeline/exceptions.py index c8598f8117..690ed9023c 100644 --- a/dlt/pipeline/exceptions.py +++ b/dlt/pipeline/exceptions.py @@ -28,9 +28,9 @@ def _to_pip_install(self) -> str: return "\n".join([f"pip install {d}" for d in self.dependencies]) -class NoPipelineException(PipelineException): - def __init__(self) -> None: - super().__init__("Please create or restore pipeline before using this function") +# class NoPipelineException(PipelineException): +# def __init__(self) -> None: +# super().__init__("Please create or restore pipeline before using this function") class PipelineConfigMissing(PipelineException): @@ -54,8 +54,9 @@ def __init__(self, config_elem: str, step: TPipelineStep, help: str = None) -> N class CannotRestorePipelineException(PipelineException): - def __init__(self, reason: str) -> None: - super().__init__(reason) + def __init__(self, pipeline_name: str, working_dir: str, reason: str) -> None: + msg = f"Pipeline with name {pipeline_name} in working directory {working_dir} could not be restored: {reason}" + super().__init__(msg) class SqlClientNotAvailable(PipelineException): @@ -63,16 +64,6 @@ def __init__(self, client_type: str) -> None: super().__init__(f"SQL Client not available in {client_type}") -class InvalidIteratorException(PipelineException): - def __init__(self, iterator: Any) -> None: - super().__init__(f"Unsupported source iterator or iterable type: {type(iterator).__name__}") - - -class InvalidItemException(PipelineException): - def __init__(self, item: Any) -> None: - super().__init__(f"Source yielded unsupported item type: {type(item).__name}. 
Only dictionaries, sequences and deferred items allowed.") - - class PipelineStepFailed(PipelineException): def __init__(self, step: TPipelineStep, exception: BaseException, run_metrics: TRunMetrics) -> None: self.stage = step diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 0a2433a2d7..178cbbbc29 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -1,9 +1,8 @@ import os from contextlib import contextmanager -from copy import deepcopy from functools import wraps from collections.abc import Sequence as C_Sequence -from typing import Any, Callable, ClassVar, List, Iterable, Iterator, Generator, Mapping, NewType, Optional, Protocol, Sequence, Tuple, Type, TypeVar, TypedDict, Union, get_type_hints, overload +from typing import Any, Callable, ClassVar, List, Iterator, Mapping, Sequence, Tuple, get_type_hints, overload from dlt.common import json, logger, signals from dlt.common.configuration.container import Container @@ -11,7 +10,7 @@ from dlt.common.runners.runnable import Runnable from dlt.common.schema.typing import TColumnSchema, TWriteDisposition from dlt.common.storages.load_storage import LoadStorage -from dlt.common.typing import ParamSpec, TFun, TSecretValue, TAny +from dlt.common.typing import ParamSpec, TFun, TSecretValue from dlt.common.runners import pool_runner as runner, TRunMetrics, initialize_runner from dlt.common.storages import LiveSchemaStorage, NormalizeStorage @@ -19,28 +18,23 @@ from dlt.common.configuration import inject_namespace from dlt.common.configuration.specs import RunConfiguration, NormalizeVolumeConfiguration, SchemaVolumeConfiguration, LoadVolumeConfiguration, PoolRunnerConfiguration from dlt.common.destination import DestinationCapabilitiesContext, DestinationReference, JobClientBase, DestinationClientConfiguration, DestinationClientDwhConfiguration -from dlt.common.schema import Schema, utils as schema_utils +from dlt.common.schema import Schema from dlt.common.storages.file_storage import FileStorage from dlt.common.utils import is_interactive -from dlt.extract.extract import ExtractorStorage, extract -from dlt.load.job_client_impl import SqlJobClientBase +from dlt.extract.extract import ExtractorStorage, extract +from dlt.extract.source import DltResource, DltSource from dlt.normalize import Normalize +from dlt.normalize.configuration import NormalizeConfiguration from dlt.load.sql_client import SqlClientBase +from dlt.load.job_client_impl import SqlJobClientBase from dlt.load.configuration import LoaderConfiguration from dlt.load import Load -from dlt.normalize.configuration import NormalizeConfiguration -from dlt.pipeline.exceptions import PipelineConfigMissing, MissingDependencyException, PipelineStepFailed, SqlClientNotAvailable -from dlt.extract.sources import DltResource, DltSource, TTableSchemaTemplate +from dlt.pipeline.exceptions import CannotRestorePipelineException, PipelineConfigMissing, MissingDependencyException, PipelineStepFailed, SqlClientNotAvailable from dlt.pipeline.typing import TPipelineStep, TPipelineState from dlt.pipeline.configuration import StateInjectableContext -# TFunParams = ParamSpec("TFunParams") -# TSelfFun = TypeVar("TSelfFun", bound=Callable[["Pipeline", ], ]) -# class TSelfFun(Protocol): -# def __call__(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: ... 
- def with_state_sync(f: TFun) -> TFun: @@ -89,6 +83,9 @@ class Pipeline: default_schema_name: str always_drop_pipeline: bool working_dir: str + pipeline_root: str + destination: DestinationReference + dataset_name: str def __init__( self, @@ -100,16 +97,14 @@ def __init__( import_schema_path: str, export_schema_path: str, always_drop_pipeline: bool, + must_restore_pipeline: bool, runtime: RunConfiguration ) -> None: self.pipeline_secret = pipeline_secret self.runtime_config = runtime - self.destination = destination - self.dataset_name = dataset_name - self.root_folder: str = None self._container = Container() - self._state: TPipelineState = {} # type: ignore + # self._state: TPipelineState = {} # type: ignore self._pipeline_storage: FileStorage = None self._schema_storage: LiveSchemaStorage = None self._schema_storage_config: SchemaVolumeConfiguration = None @@ -117,14 +112,26 @@ def __init__( self._load_storage_config: LoadVolumeConfiguration = None initialize_runner(self.runtime_config) - self._configure(pipeline_name, working_dir, import_schema_path, export_schema_path, always_drop_pipeline) + # initialize pipeline working dir + self._init_working_dir(pipeline_name, working_dir) + # initialize or restore state + with self._managed_state(): + # see if state didn't change the pipeline name + if pipeline_name != self.pipeline_name: + raise CannotRestorePipelineException(pipeline_name, working_dir, f"working directory contains state for pipeline with name {self.pipeline_name}") + # at this moment state is recovered so we overwrite the state with the values from init + self.destination = destination or self.destination # changing the destination could be dangerous if pipeline has not loaded items + self.dataset_name = dataset_name or self.dataset_name + self.always_drop_pipeline = always_drop_pipeline + self._configure(import_schema_path, export_schema_path, must_restore_pipeline) def drop(self) -> "Pipeline": """Deletes existing pipeline state, schemas and drops datasets at the destination if present""" - # drop the data for all known schemas - for schema in self._schema_storage: - with self._get_destination_client(self._schema_storage.load_schema(schema)) as client: - client.initialize_storage(wipe_data=True) + if self.destination: + # drop the data for all known schemas + for schema in self._schema_storage: + with self._get_destination_client(self._schema_storage.load_schema(schema)) as client: + client.initialize_storage(wipe_data=True) # reset the pipeline working dir self._create_pipeline() # clone the pipeline @@ -137,6 +144,7 @@ def drop(self) -> "Pipeline": self._schema_storage.config.import_schema_path, self._schema_storage.config.export_schema_path, self.always_drop_pipeline, + True, self.runtime_config ) @@ -170,7 +178,7 @@ def drop(self) -> "Pipeline": def extract( self, data: Any, - table_name: str, + table_name: str = None, parent_table_name: str = None, write_disposition: TWriteDisposition = None, columns: Sequence[TColumnSchema] = None, @@ -242,7 +250,8 @@ def item_to_source(data_item: Any) -> DltSource: for s in sources: self._extract_source(s, max_parallel_items, workers) except Exception as exc: - raise PipelineStepFailed("extract", self.last_run_exception, runner.LAST_RUN_METRICS) from exc + # TODO: provide metrics from extractor + raise PipelineStepFailed("extract", exc, runner.LAST_RUN_METRICS) from exc @with_schemas_sync @@ -363,7 +372,7 @@ def list_failed_jobs_in_package(self, load_id: str) -> Sequence[Tuple[str, str]] failed_message = storage.storage.load(file + 
".exception") except FileNotFoundError: failed_message = None - failed_jobs.append((file, failed_message)) + failed_jobs.append((storage.storage.make_full_path(file), failed_message)) return failed_jobs def sync_schema(self, schema_name: str = None) -> None: @@ -385,30 +394,32 @@ def _get_load_storage(self) -> LoadStorage: caps = self._get_destination_capabilities() return LoadStorage(True, caps.preferred_loader_file_format, caps.supported_loader_file_formats, self._load_storage_config) - @with_state_sync - def _configure(self, pipeline_name: str, working_dir: str, import_schema_path: str, export_schema_path: str, always_drop_pipeline: bool) -> None: + def _init_working_dir(self, pipeline_name: str, working_dir: str) -> None: self.pipeline_name = pipeline_name self.working_dir = working_dir - self.always_drop_pipeline = always_drop_pipeline - # compute the folder that keeps all of the pipeline state FileStorage.validate_file_name_component(self.pipeline_name) - self.root_folder = os.path.join(self.working_dir, self.pipeline_name) + self.pipeline_root = os.path.join(working_dir, pipeline_name) + # create pipeline working dir + self._pipeline_storage = FileStorage(self.pipeline_root, makedirs=False) + + def _configure(self, import_schema_path: str, export_schema_path: str, must_restore_pipeline: bool) -> None: # create default configs - # self._pool_config = PoolRunnerConfiguration(is_single_run=True, exit_on_exception=True) self._schema_storage_config = SchemaVolumeConfiguration( - schema_volume_path=os.path.join(self.root_folder, "schemas"), + schema_volume_path=os.path.join(self.pipeline_root, "schemas"), import_schema_path=import_schema_path, export_schema_path=export_schema_path ) - self._normalize_storage_config = NormalizeVolumeConfiguration(normalize_volume_path=os.path.join(self.root_folder, "normalize")) - self._load_storage_config = LoadVolumeConfiguration(load_volume_path=os.path.join(self.root_folder, "load"),) + self._normalize_storage_config = NormalizeVolumeConfiguration(normalize_volume_path=os.path.join(self.pipeline_root, "normalize")) + self._load_storage_config = LoadVolumeConfiguration(load_volume_path=os.path.join(self.pipeline_root, "load"),) - # create pipeline working dir - self._pipeline_storage = FileStorage(self.root_folder, makedirs=False) + # are we running again? 
+ has_state = self._pipeline_storage.has_file(Pipeline.STATE_FILE) + if must_restore_pipeline and not has_state: + raise CannotRestorePipelineException(self.pipeline_name, self.working_dir, f"the pipeline was not found in {self.pipeline_root}.") # restore pipeline if folder exists and contains state - if self._pipeline_storage.has_file(Pipeline.STATE_FILE) and not always_drop_pipeline: + if has_state and (not self.always_drop_pipeline or must_restore_pipeline): self._restore_pipeline() else: # this will erase the existing working folder @@ -424,12 +435,7 @@ def _create_pipeline(self) -> None: self._pipeline_storage.create_folder("", exists_ok=False) def _restore_pipeline(self) -> None: - self._restore_state() - - def _restore_state(self) -> None: - self._state.clear() # type: ignore - restored_state: TPipelineState = json.loads(self._pipeline_storage.load(Pipeline.STATE_FILE)) - self._state.update(restored_state) # type: ignore + pass def _extract_source(self, source: DltSource, max_parallel_items: int, workers: int) -> None: # discover the schema from source @@ -528,27 +534,43 @@ def _get_destination_capabilities(self) -> DestinationCapabilitiesContext: def _get_dataset_name(self) -> str: return self.dataset_name or self.pipeline_name + def _get_state(self) -> TPipelineState: + try: + state: TPipelineState = json.loads(self._pipeline_storage.load(Pipeline.STATE_FILE)) + except FileNotFoundError: + state = {} + return state + @contextmanager def _managed_state(self) -> Iterator[TPipelineState]: + # load current state + state = self._get_state() # write props to pipeline variables for prop in Pipeline.STATE_PROPS: - setattr(self, prop, self._state.get(prop)) - # backup the state - backup_state = deepcopy(self._state) + setattr(self, prop, state.get(prop)) + if "destination" in state: + self.destination = DestinationReference.from_name(self.destination) + try: - yield self._state + yield state except Exception: # restore old state - self._state.clear() # type: ignore - self._state.update(backup_state) # type: ignore + # currently do nothing - state is not preserved in memory, only saved raise else: # update state props for prop in Pipeline.STATE_PROPS: - self._state[prop] = getattr(self, prop) # type: ignore + state[prop] = getattr(self, prop) # type: ignore + if self.destination: + state["destination"] = self.destination.__name__ + + # load state from storage to be merged with pipeline changes, currently we assume no parallel changes # compare backup and new state, save only if different - new_state = json.dumps(self._state) - old_state = json.dumps(backup_state) + backup_state = self._get_state() + print(state) + print(backup_state) + new_state = json.dumps(state, sort_keys=True) + old_state = json.dumps(backup_state, sort_keys=True) # persist old state if new_state != old_state: self._pipeline_storage.save(Pipeline.STATE_FILE, new_state) diff --git a/dlt/pipeline/typing.py b/dlt/pipeline/typing.py index adb9f2f6e2..8a1248e93b 100644 --- a/dlt/pipeline/typing.py +++ b/dlt/pipeline/typing.py @@ -3,11 +3,11 @@ TPipelineStep = Literal["extract", "normalize", "load"] -class TPipelineState(TypedDict): +class TPipelineState(TypedDict, total=False): pipeline_name: str dataset_name: str default_schema_name: Optional[str] - # destination: Optional[str] + destination: Optional[str] # TSourceState = NewType("TSourceState", DictStrAny) From dab58bf24a033b76d3cbb6ee9bc1da2f2e7d2936 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sun, 30 Oct 2022 01:31:44 +0200 Subject: [PATCH 59/66] fixes dependent 
resources handling, partially adds missing exceptions --- dlt/common/pipeline.py | 10 ++- dlt/common/source.py | 88 ------------------- dlt/common/typing.py | 7 +- dlt/extract/decorators.py | 106 +++++++++++++++++++---- dlt/extract/exceptions.py | 80 +++++++++++++++++ dlt/extract/extract.py | 6 +- dlt/extract/pipe.py | 78 ++++++++--------- dlt/extract/{sources.py => source.py} | 120 +++++++++++++------------- dlt/extract/typing.py | 22 +++++ 9 files changed, 303 insertions(+), 214 deletions(-) delete mode 100644 dlt/common/source.py rename dlt/extract/{sources.py => source.py} (75%) create mode 100644 dlt/extract/typing.py diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index 3be316dad6..13781cc23a 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -26,16 +26,18 @@ def run( @configspec(init=True) class PipelineContext(ContainerInjectableContext): + # TODO: declare unresolvable generic types that will be allowed by configpec _deferred_pipeline: Any _pipeline: Any can_create_default: ClassVar[bool] = False def pipeline(self) -> SupportsPipeline: + """Creates or returns exiting pipeline""" if not self._pipeline: # delayed pipeline creation self._pipeline = self._deferred_pipeline() - return self._pipeline + return self._pipeline # type: ignore def activate(self, pipeline: SupportsPipeline) -> None: self._pipeline = pipeline @@ -44,10 +46,16 @@ def is_activated(self) -> bool: return self._pipeline is not None def __init__(self, deferred_pipeline: Callable[..., SupportsPipeline]) -> None: + """Initialize the context with a function returning the Pipeline object to allow creation on first use""" self._deferred_pipeline = deferred_pipeline def get_default_working_dir() -> str: + """ Gets default working dir of the pipeline, which may be + 1. in user home directory ~/.dlt/pipelines/ + 2. if current user is root in /var/dlt/pipelines + 3. if current user does not have a home directory in /tmp/dlt/pipelines + """ if os.geteuid() == 0: # we are root so use standard /var return os.path.join("/var", "dlt", "pipelines") diff --git a/dlt/common/source.py b/dlt/common/source.py deleted file mode 100644 index 0d567fdf0f..0000000000 --- a/dlt/common/source.py +++ /dev/null @@ -1,88 +0,0 @@ -from collections import abc -from functools import wraps -from typing import Any, Callable, Optional, Sequence, TypeVar, Union, TypedDict, List, Awaitable - -from dlt.common import logger -from dlt.common.time import sleep -from dlt.common.typing import ParamSpec, TDataItem - - -# possible types of items yielded by the source -# 1. document (mapping from str to any type) -# 2. 
Iterable (ie list) on the mapping above for returning many documents with single yield -TItem = Union[TDataItem, Sequence[TDataItem]] -TBoundItem = TypeVar("TBoundItem", bound=TItem) -TDeferred = Callable[[], TBoundItem] - -_TFunParams = ParamSpec("_TFunParams") - -# TODO: cleanup those types -TDirectDataItem = Union[TDataItem, List[TDataItem]] -TDeferredDataItem = Callable[[], TDirectDataItem] -TAwaitableDataItem = Awaitable[TDirectDataItem] -TResolvableDataItem = Union[TDirectDataItem, TDeferredDataItem, TAwaitableDataItem] - -TDynHintType = TypeVar("TDynHintType") -TFunHintTemplate = Callable[[TDataItem], TDynHintType] -TTableHintTemplate = Union[TDynHintType, TFunHintTemplate[TDynHintType]] - -# name of dlt metadata as part of the item -# DLT_METADATA_FIELD = "_dlt_meta" - - -# class TEventDLTMeta(TypedDict, total=False): -# table_name: str # a root table in which store the event - - -# def append_dlt_meta(item: TBoundItem, name: str, value: Any) -> TBoundItem: -# if isinstance(item, abc.Sequence): -# for i in item: -# i.setdefault(DLT_METADATA_FIELD, {})[name] = value -# elif isinstance(item, dict): -# item.setdefault(DLT_METADATA_FIELD, {})[name] = value - -# return item - - -# def with_table_name(item: TBoundItem, table_name: str) -> TBoundItem: -# # normalize table name before adding -# return append_dlt_meta(item, "table_name", table_name) - - -# def get_table_name(item: StrAny) -> Optional[str]: -# if DLT_METADATA_FIELD in item: -# meta: TEventDLTMeta = item[DLT_METADATA_FIELD] -# return meta.get("table_name", None) -# return None - - -def with_retry(max_retries: int = 3, retry_sleep: float = 1.0) -> Callable[[Callable[_TFunParams, TBoundItem]], Callable[_TFunParams, TBoundItem]]: - - def decorator(f: Callable[_TFunParams, TBoundItem]) -> Callable[_TFunParams, TBoundItem]: - - def _wrap(*args: Any, **kwargs: Any) -> TBoundItem: - attempts = 0 - while True: - try: - return f(*args, **kwargs) - except Exception as exc: - if attempts == max_retries: - raise - attempts += 1 - logger.warning(f"Exception {exc} in iterator, retrying {attempts} / {max_retries}") - sleep(retry_sleep) - - return _wrap - - return decorator - - -def defer_iterator(f: Callable[_TFunParams, TBoundItem]) -> Callable[_TFunParams, TDeferred[TBoundItem]]: - - @wraps(f) - def _wrap(*args: Any, **kwargs: Any) -> TDeferred[TBoundItem]: - def _curry() -> TBoundItem: - return f(*args, **kwargs) - return _curry - - return _wrap diff --git a/dlt/common/typing.py b/dlt/common/typing.py index 089aa8189c..72d7847398 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -1,7 +1,7 @@ from collections.abc import Mapping as C_Mapping, Sequence as C_Sequence from re import Pattern as _REPattern -from typing import Callable, Dict, Any, Literal, Mapping, NewType, Tuple, Type, TypeVar, Generic, Protocol, Iterable, TYPE_CHECKING, Union, runtime_checkable, get_args, get_origin -from typing_extensions import ParamSpec, TypeAlias, TypeGuard +from typing import Callable, Dict, Any, Literal, List, Mapping, NewType, Tuple, Type, TypeVar, Generic, Protocol, TYPE_CHECKING, Union, runtime_checkable, get_args, get_origin +from typing_extensions import TypeAlias if TYPE_CHECKING: from _typeshed import StrOrBytesPath @@ -24,7 +24,8 @@ TAny = TypeVar("TAny", bound=Any) TAnyClass = TypeVar("TAnyClass", bound=object) TSecretValue = NewType("TSecretValue", str) # represent secret value ie. 
coming from Kubernetes/Docker secrets or other providers -TDataItem: TypeAlias = Any # a single data item extracted from data source, normalized and loaded +TDataItem: TypeAlias = object # a single data item as extracted from data source +TDataItems: TypeAlias = Union[TDataItem, List[TDataItem]] # a single or many data items as extracted from the data source ConfigValue: None = None # a value of type None indicating argument that may be injected by config provider diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index e385fb4b3a..2316cd4b29 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -1,18 +1,18 @@ import inspect from types import ModuleType from makefun import wraps -from typing import Any, Callable, Dict, Iterator, List, NamedTuple, Optional, Sequence, Tuple, Type, Union, overload +from typing import Any, Callable, Dict, Iterator, List, NamedTuple, Optional, Tuple, Type, TypeVar, Union, overload from dlt.common.configuration import with_config, get_fun_spec from dlt.common.configuration.specs import BaseConfiguration from dlt.common.exceptions import ArgumentsOverloadException from dlt.common.schema.schema import Schema from dlt.common.schema.typing import TTableSchemaColumns, TWriteDisposition -from dlt.common.source import TTableHintTemplate, TFunHintTemplate -from dlt.common.typing import AnyFun, TFun, ParamSpec +from dlt.common.typing import AnyFun, ParamSpec, TDataItems from dlt.common.utils import is_inner_function -from dlt.extract.sources import DltResource, DltSource +from dlt.extract.typing import TTableHintTemplate, TFunHintTemplate +from dlt.extract.source import DltResource, DltSource class SourceInfo(NamedTuple): @@ -37,17 +37,19 @@ def source(func: None = ..., /, name: str = None, schema: Schema = None, spec: T def source(func: Optional[AnyFun] = None, /, name: str = None, schema: Schema = None, spec: Type[BaseConfiguration] = None) -> Any: - if name and schema: - raise ArgumentsOverloadException("Source name cannot be set if schema is present") + # if name and schema: + # raise ArgumentsOverloadException( + # "source name cannot be set if schema is present", + # "source", + # "You can provide either the Schema instance directly in `schema` argument or the name of ") def decorator(f: Callable[TSourceFunParams, Any]) -> Callable[TSourceFunParams, DltSource]: nonlocal schema, name - # extract name - if schema: - name = schema.name - else: - name = name or f.__name__ + # source name is passed directly or taken from decorated function name + name = name or f.__name__ + + if not schema: # create or load default schema # TODO: we need a convention to load ie. 
load the schema from file with name_schema.yaml schema = Schema(name) @@ -190,8 +192,8 @@ def decorator(f: Callable[TResourceFunParams, Any]) -> Callable[TResourceFunPara resource_name = name or f.__name__ # if f is not a generator (does not yield) raise Exception - if not inspect.isgeneratorfunction(inspect.unwrap(f)): - raise ResourceFunNotGenerator() + # if not inspect.isgeneratorfunction(inspect.unwrap(f)): + # raise ResourceFunNotGenerator() # do not inject config values for inner functions, we assume that they are part of the source SPEC: Type[BaseConfiguration] = None @@ -204,16 +206,16 @@ def decorator(f: Callable[TResourceFunParams, Any]) -> Callable[TResourceFunPara # get spec for wrapped function SPEC = get_fun_spec(conf_f) - @wraps(conf_f, func_name=resource_name) - def _wrap(*args: Any, **kwargs: Any) -> DltResource: - return make_resource(resource_name, f(*args, **kwargs)) + # @wraps(conf_f, func_name=resource_name) + # def _wrap(*args: Any, **kwargs: Any) -> DltResource: + # return make_resource(resource_name, f(*args, **kwargs)) # store the standalone resource information if SPEC: - _SOURCES[_wrap.__qualname__] = SourceInfo(SPEC, _wrap, inspect.getmodule(f)) + _SOURCES[f.__qualname__] = SourceInfo(SPEC, f, inspect.getmodule(f)) # the typing is right, but makefun.wraps does not preserve signatures - return _wrap # type: ignore + return make_resource(resource_name, f) # type: ignore # if data is callable or none use decorator @@ -258,4 +260,70 @@ def _get_source_for_inner_function(f: AnyFun) -> Optional[SourceInfo]: # reveal_type(resource([], name="aaaa")) -# reveal_type(resource("aaaaa", name="aaaa")) \ No newline at end of file +# reveal_type(resource("aaaaa", name="aaaa")) + +# name of dlt metadata as part of the item +# DLT_METADATA_FIELD = "_dlt_meta" + + +# class TEventDLTMeta(TypedDict, total=False): +# table_name: str # a root table in which store the event + + +# def append_dlt_meta(item: TBoundItem, name: str, value: Any) -> TBoundItem: +# if isinstance(item, abc.Sequence): +# for i in item: +# i.setdefault(DLT_METADATA_FIELD, {})[name] = value +# elif isinstance(item, dict): +# item.setdefault(DLT_METADATA_FIELD, {})[name] = value + +# return item + + +# def with_table_name(item: TBoundItem, table_name: str) -> TBoundItem: +# # normalize table name before adding +# return append_dlt_meta(item, "table_name", table_name) + + +# def get_table_name(item: StrAny) -> Optional[str]: +# if DLT_METADATA_FIELD in item: +# meta: TEventDLTMeta = item[DLT_METADATA_FIELD] +# return meta.get("table_name", None) +# return None + + +# def with_retry(max_retries: int = 3, retry_sleep: float = 1.0) -> Callable[[Callable[_TFunParams, TBoundItem]], Callable[_TFunParams, TBoundItem]]: + +# def decorator(f: Callable[_TFunParams, TBoundItem]) -> Callable[_TFunParams, TBoundItem]: + +# def _wrap(*args: Any, **kwargs: Any) -> TBoundItem: +# attempts = 0 +# while True: +# try: +# return f(*args, **kwargs) +# except Exception as exc: +# if attempts == max_retries: +# raise +# attempts += 1 +# logger.warning(f"Exception {exc} in iterator, retrying {attempts} / {max_retries}") +# sleep(retry_sleep) + +# return _wrap + +# return decorator + + +TBoundItems = TypeVar("TBoundItems", bound=TDataItems) +TDeferred = Callable[[], TBoundItems] +TDeferredFunParams = ParamSpec("TDeferredFunParams") + + +def defer(f: Callable[TDeferredFunParams, TBoundItems]) -> Callable[TDeferredFunParams, TDeferred[TBoundItems]]: + + @wraps(f) + def _wrap(*args: Any, **kwargs: Any) -> TDeferred[TBoundItems]: + def 
_curry() -> TBoundItems: + return f(*args, **kwargs) + return _curry + + return _wrap # type: ignore diff --git a/dlt/extract/exceptions.py b/dlt/extract/exceptions.py index 6582b526b7..bf0d4bff18 100644 --- a/dlt/extract/exceptions.py +++ b/dlt/extract/exceptions.py @@ -1,5 +1,85 @@ +from typing import Any, Type from dlt.common.exceptions import DltException class ExtractorException(DltException): pass + + +class DltSourceException(DltException): + pass + + +class DltResourceException(DltSourceException): + def __init__(self, resource_name: str, msg: str) -> None: + self.resource_name = resource_name + super().__init__(msg) + + +class PipeException(DltException): + pass + + +class CreatePipeException(PipeException): + pass + + +class PipeItemProcessingError(PipeException): + pass + + +# class InvalidIteratorException(PipelineException): +# def __init__(self, iterator: Any) -> None: +# super().__init__(f"Unsupported source iterator or iterable type: {type(iterator).__name__}") + + +# class InvalidItemException(PipelineException): +# def __init__(self, item: Any) -> None: +# super().__init__(f"Source yielded unsupported item type: {type(item).__name}. Only dictionaries, sequences and deferred items allowed.") + + +class ResourceNameMissing(DltResourceException): + def __init__(self) -> None: + super().__init__(None, """Resource name is missing. If you create a resource directly from data ie. from a list you must pass the name explicitly in `name` argument. + Please note that for resources created from functions or generators, the name is the function name by default.""") + + +class InvalidResourceDataType(DltResourceException): + def __init__(self, resource_name: str, item: Any, _typ: Type[Any], msg: str) -> None: + self.item = item + self._typ = _typ + super().__init__(resource_name, f"Cannot create resource {resource_name} from specified data. " + msg) + + +class InvalidResourceAsyncDataType(InvalidResourceDataType): + def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None: + super().__init__(resource_name, item, _typ, "Async iterators and generators are not valid resources. Please use standard iterators and generators that yield Awaitables instead (for example by yielding from async function without await") + + +class InvalidResourceBasicDataType(InvalidResourceDataType): + def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None: + super().__init__(resource_name, item, _typ, f"Resources cannot be strings or dictionaries but {_typ.__name__} was provided. Please pass your data in a list or as a function yielding items. If you want to process just one data item, enclose it in a list.") + + +class GeneratorFunctionNotAllowedAsParentResource(DltResourceException): + def __init__(self, resource_name: str, func_name: str) -> None: + self.func_name = func_name + super().__init__(resource_name, f"A parent resource {resource_name} of dependent resource {resource_name} is a function but must be a resource. Please decorate function") + + +class TableNameMissing(DltSourceException): + def __init__(self) -> None: + super().__init__("""Table name is missing in table template. 
Please provide a string or a function that takes a data item as an argument""") + + +class InconsistentTableTemplate(DltSourceException): + def __init__(self, reason: str) -> None: + msg = f"A set of table hints provided to the resource is inconsistent: {reason}" + super().__init__(msg) + + +class DataItemRequiredForDynamicTableHints(DltSourceException): + def __init__(self, resource_name: str) -> None: + self.resource_name = resource_name + super().__init__(f"""An instance of resource's data required to generate table schema in resource {resource_name}. + One of table hints for that resource (typically table name) is a function and hint is computed separately for each instance of data extracted from that resource.""") diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index c9a5476d50..e464a39031 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -2,14 +2,14 @@ from typing import ClassVar, List from dlt.common.utils import uniq_id -from dlt.common.source import TDirectDataItem, TDataItem +from dlt.common.typing import TDataItems, TDataItem from dlt.common.schema import utils, TSchemaUpdate from dlt.common.storages import NormalizeStorage, DataItemStorage from dlt.common.configuration.specs import NormalizeVolumeConfiguration from dlt.extract.pipe import PipeIterator -from dlt.extract.sources import DltResource, DltSource +from dlt.extract.source import DltResource, DltSource class ExtractorStorage(DataItemStorage, NormalizeStorage): @@ -52,7 +52,7 @@ def extract(source: DltSource, storage: ExtractorStorage, *, max_parallel_items: schema = source.schema extract_id = storage.create_extract_id() - def _write_item(table_name: str, item: TDirectDataItem) -> None: + def _write_item(table_name: str, item: TDataItems) -> None: # normalize table name before writing so the name match the name in schema # note: normalize function should be cached so there's almost no penalty on frequent calling # note: column schema is not required for jsonl writer used here diff --git a/dlt/extract/pipe.py b/dlt/extract/pipe.py index c70f1b684e..ae56b36a7f 100644 --- a/dlt/extract/pipe.py +++ b/dlt/extract/pipe.py @@ -8,27 +8,28 @@ from dlt.common.configuration.inject import with_config from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec -from dlt.common.typing import TDataItem -from dlt.common.source import TDirectDataItem, TResolvableDataItem +from dlt.common.typing import TDataItem, TDataItems + +from dlt.extract.exceptions import CreatePipeException, PipeItemProcessingError +from dlt.extract.typing import TPipedDataItems if TYPE_CHECKING: - TItemFuture = Future[TDirectDataItem] + TItemFuture = Future[TDataItems] else: TItemFuture = Future -from dlt.common.exceptions import DltException from dlt.common.time import sleep class PipeItem(NamedTuple): - item: TDirectDataItem + item: TDataItems step: int pipe: "Pipe" class ResolvablePipeItem(NamedTuple): # mypy unable to handle recursive types, ResolvablePipeItem should take itself in "item" - item: Union[TResolvableDataItem, Iterator[TResolvableDataItem]] + item: Union[TPipedDataItems, Iterator[TPipedDataItems]] step: int pipe: "Pipe" @@ -40,18 +41,18 @@ class FuturePipeItem(NamedTuple): class SourcePipeItem(NamedTuple): - item: Union[Iterator[TResolvableDataItem], Iterator[ResolvablePipeItem]] + item: Union[Iterator[TPipedDataItems], Iterator[ResolvablePipeItem]] step: int pipe: "Pipe" # pipeline step may be iterator of data items or mapping function that returns data item or another iterator 
TPipeStep = Union[ - Iterable[TResolvableDataItem], - Iterator[TResolvableDataItem], - Callable[[TDirectDataItem], TResolvableDataItem], - Callable[[TDirectDataItem], Iterator[TResolvableDataItem]], - Callable[[TDirectDataItem], Iterator[ResolvablePipeItem]] + Iterable[TPipedDataItems], + Iterator[TPipedDataItems], + Callable[[TDataItems], TPipedDataItems], + Callable[[TDataItems], Iterator[TPipedDataItems]], + Callable[[TDataItems], Iterator[ResolvablePipeItem]] ] @@ -67,7 +68,7 @@ def add_pipe(self, pipe: "Pipe", step: int = -1) -> None: def has_pipe(self, pipe: "Pipe") -> bool: return pipe in [p[0] for p in self._pipes] - def __call__(self, item: TDirectDataItem) -> Iterator[ResolvablePipeItem]: + def __call__(self, item: TDataItems) -> Iterator[ResolvablePipeItem]: for i, (pipe, step) in enumerate(self._pipes): _it = item if i == 0 else deepcopy(item) # always start at the beginning @@ -78,7 +79,7 @@ class FilterItem: def __init__(self, filter_f: Callable[[TDataItem], bool]) -> None: self._filter_f = filter_f - def __call__(self, item: TDirectDataItem) -> Optional[TDirectDataItem]: + def __call__(self, item: TDataItems) -> Optional[TDataItems]: # item may be a list TDataItem or a single TDataItem if isinstance(item, list): item = [i for i in item if self._filter_f(i)] @@ -99,7 +100,7 @@ def __init__(self, name: str, steps: List[TPipeStep] = None, parent: "Pipe" = No self.parent = parent @classmethod - def from_iterable(cls, name: str, gen: Union[Iterable[TResolvableDataItem], Iterator[TResolvableDataItem]], parent: "Pipe" = None) -> "Pipe": + def from_iterable(cls, name: str, gen: Union[Iterable[TPipedDataItems], Iterator[TPipedDataItems]], parent: "Pipe" = None) -> "Pipe": if isinstance(gen, Iterable): gen = iter(gen) return cls(name, [gen], parent=parent) @@ -186,6 +187,12 @@ def full_pipe(self) -> "Pipe": pipe.extend(self._steps) return Pipe(self.name, pipe) + def evaluate_head(self) -> None: + # if pipe head is callable then call it + if self.parent is None: + if callable(self.head): + self._steps[0] = self.head() + def __repr__(self) -> str: return f"Pipe {self.name} ({self._pipe_id}) at {id(self)}" @@ -198,7 +205,6 @@ class PipeIteratorConfiguration(BaseConfiguration): workers: int = 5 futures_poll_interval: float = 0.01 - def __init__(self, max_parallel_items: int, workers: int, futures_poll_interval: float) -> None: self.max_parallel_items = max_parallel_items self.workers = workers @@ -215,8 +221,8 @@ def __init__(self, max_parallel_items: int, workers: int, futures_poll_interval: def from_pipe(cls, pipe: Pipe, *, max_parallel_items: int = 100, workers: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": if pipe.parent: pipe = pipe.full_pipe() - # TODO: if pipe head is callable then call it now # head must be iterator + pipe.evaluate_head() assert isinstance(pipe.head, Iterator) # create extractor extract = cls(max_parallel_items, workers, futures_poll_interval) @@ -228,10 +234,12 @@ def from_pipe(cls, pipe: Pipe, *, max_parallel_items: int = 100, workers: int = @with_config(spec=PipeIteratorConfiguration) def from_pipes(cls, pipes: Sequence[Pipe], yield_parents: bool = True, *, max_parallel_items: int = 100, workers: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": extract = cls(max_parallel_items, workers, futures_poll_interval) + # TODO: consider removing cloning. 
pipe are single use and may be iterated only once, here we modify an immediately run # clone all pipes before iterating (recursively) as we will fork them and this add steps pipes = PipeIterator.clone_pipes(pipes) def _fork_pipeline(pipe: Pipe) -> None: + print(f"forking: {pipe.name}") if pipe.parent: # fork the parent pipe pipe.parent.fork(pipe) @@ -242,6 +250,7 @@ def _fork_pipeline(pipe: Pipe) -> None: _fork_pipeline(pipe.parent) else: # head of independent pipe must be iterator + pipe.evaluate_head() assert isinstance(pipe.head, Iterator) # add every head as source only once if not any(i.pipe == pipe for i in extract._sources): @@ -277,20 +286,22 @@ def __next__(self) -> PipeItem: sleep(self.futures_poll_interval) continue + + item = pipe_item.item # if item is iterator, then add it as a new source - if isinstance(pipe_item.item, Iterator): + if isinstance(item, Iterator): # print(f"adding iterable {item}") - self._sources.append(SourcePipeItem(pipe_item.item, pipe_item.step, pipe_item.pipe)) + self._sources.append(SourcePipeItem(item, pipe_item.step, pipe_item.pipe)) pipe_item = None continue - if isinstance(pipe_item.item, Awaitable) or callable(pipe_item.item): + if isinstance(item, Awaitable) or callable(item): # do we have a free slot or one of the slots is done? if len(self._futures) < self.max_parallel_items or self._next_future() >= 0: - if isinstance(pipe_item.item, Awaitable): - future = asyncio.run_coroutine_threadsafe(pipe_item.item, self._ensure_async_pool()) - else: - future = self._ensure_thread_pool().submit(pipe_item.item) + if isinstance(item, Awaitable): + future = asyncio.run_coroutine_threadsafe(item, self._ensure_async_pool()) + elif callable(item): + future = self._ensure_thread_pool().submit(item) # print(future) self._futures.append(FuturePipeItem(future, pipe_item.step, pipe_item.pipe)) # type: ignore # pipe item consumed for now, request a new one @@ -306,7 +317,7 @@ def __next__(self) -> PipeItem: # print(pipe_item) if pipe_item.step == len(pipe_item.pipe) - 1: # must be resolved - if isinstance(pipe_item.item, (Iterator, Awaitable)) or callable(pipe_item.pipe): + if isinstance(item, (Iterator, Awaitable)) or callable(pipe_item.pipe): raise PipeItemProcessingError("Pipe item not processed", pipe_item) # mypy not able to figure out that item was resolved return pipe_item # type: ignore @@ -314,8 +325,8 @@ def __next__(self) -> PipeItem: # advance to next step step = pipe_item.pipe[pipe_item.step + 1] assert callable(step) - item = step(pipe_item.item) - pipe_item = ResolvablePipeItem(item, pipe_item.step + 1, pipe_item.pipe) + next_item = step(item) + pipe_item = ResolvablePipeItem(next_item, pipe_item.step + 1, pipe_item.pipe) def _ensure_async_pool(self) -> asyncio.AbstractEventLoop: @@ -434,16 +445,3 @@ def clone_pipes(pipes: Sequence[Pipe]) -> Sequence[Pipe]: clone = clone.parent return cloned_pipes - - -class PipeException(DltException): - pass - - -class CreatePipeException(PipeException): - pass - - -class PipeItemProcessingError(PipeException): - pass - diff --git a/dlt/extract/sources.py b/dlt/extract/source.py similarity index 75% rename from dlt/extract/sources.py rename to dlt/extract/source.py index fb29e1ace4..fe99c3a1cb 100644 --- a/dlt/extract/sources.py +++ b/dlt/extract/source.py @@ -2,27 +2,18 @@ from copy import deepcopy import inspect from collections.abc import Mapping as C_Mapping -from typing import AsyncIterable, AsyncIterator, Coroutine, Dict, Generator, Iterable, Iterator, List, NamedTuple, Set, TypedDict, Union, Awaitable, 
Callable, Sequence, TypeVar, cast, Optional, Any +from typing import AsyncIterable, AsyncIterator, Iterable, Iterator, List, Set, Sequence, Union, cast, Any -from dlt.common.exceptions import DltException -from dlt.common.typing import TDataItem -from dlt.common.source import TFunHintTemplate, TDirectDataItem, TTableHintTemplate from dlt.common.schema import Schema from dlt.common.schema.utils import new_table -from dlt.common.schema.typing import TColumnSchema, TPartialTableSchema, TTableSchema, TTableSchemaColumns, TWriteDisposition +from dlt.common.schema.typing import TPartialTableSchema, TTableSchemaColumns, TWriteDisposition +from dlt.common.typing import TDataItem, TDataItems from dlt.common.configuration.container import Container from dlt.common.pipeline import PipelineContext -from dlt.extract.pipe import FilterItem, Pipe, CreatePipeException, PipeIterator - - -class TTableSchemaTemplate(TypedDict, total=False): - name: TTableHintTemplate[str] - description: TTableHintTemplate[str] - write_disposition: TTableHintTemplate[TWriteDisposition] - # table_sealed: Optional[bool] - parent: TTableHintTemplate[str] - columns: TTableHintTemplate[TTableSchemaColumns] +from dlt.extract.typing import TFunHintTemplate, TTableHintTemplate, TTableSchemaTemplate +from dlt.extract.pipe import FilterItem, Pipe, PipeIterator +from dlt.extract.exceptions import CreatePipeException, DataItemRequiredForDynamicTableHints, GeneratorFunctionNotAllowedAsParentResource, InconsistentTableTemplate, InvalidResourceAsyncDataType, InvalidResourceBasicDataType, ResourceNameMissing, TableNameMissing class DltResourceSchema: @@ -102,7 +93,7 @@ def new_table_template( columns: TTableHintTemplate[TTableSchemaColumns] = None, ) -> TTableSchemaTemplate: if not table_name: - raise InvalidTableSchemaTemplate("Table template name must be a string or function taking TDataItem") + raise TableNameMissing() # create a table schema template where hints can be functions taking TDataItem if isinstance(columns, C_Mapping): # new_table accepts a sequence @@ -111,10 +102,11 @@ def new_table_template( new_template: TTableSchemaTemplate = new_table(table_name, parent_table_name, write_disposition=write_disposition, columns=columns) # type: ignore # if any of the hints is a function then name must be as well if any(callable(v) for k, v in new_template.items() if k != "name") and not callable(table_name): - raise InvalidTableSchemaTemplate("Table name must be a function if any other table hint is a function") + raise InconsistentTableTemplate("Table name must be a function if any other table hint is a function") return new_template -class DltResource(Iterable[TDirectDataItem], DltResourceSchema): + +class DltResource(Iterable[TDataItems], DltResourceSchema): def __init__(self, pipe: Pipe, table_schema_template: TTableSchemaTemplate, selected: bool): self.name = pipe.name self.selected = selected @@ -122,15 +114,7 @@ def __init__(self, pipe: Pipe, table_schema_template: TTableSchemaTemplate, sele super().__init__(self.name, table_schema_template) @classmethod - def from_data(cls, data: Any, name: str = None, table_schema_template: TTableSchemaTemplate = None, selected: bool = True, depends_on: "DltResource" = None) -> "DltResource": - # call functions assuming that they do not take any parameters, typically they are generator functions - if callable(data): - # use inspect.isgeneratorfunction to see if this is generator or not - # if it is then call it, if not then keep the callable assuming that it will return iterable/iterator - # if 
inspect.isgeneratorfunction(data): - # data = data() - # else: - data = data() + def from_data(cls, data: Any, name: str = None, table_schema_template: TTableSchemaTemplate = None, selected: bool = True, depends_on: Union["DltResource", Pipe] = None) -> "DltResource": if isinstance(data, DltResource): return data @@ -138,34 +122,58 @@ def from_data(cls, data: Any, name: str = None, table_schema_template: TTableSch if isinstance(data, Pipe): return cls(data, table_schema_template, selected) + if callable(data): + name = name or data.__name__ + # function must be a generator + if not inspect.isgeneratorfunction(inspect.unwrap(data)): + raise ResourceFunctionNotAGenerator(name) + + # if generator, take name from it + if inspect.isgenerator(data): + name = name or data.__name__ + + # name is mandatory + if not name: + raise ResourceNameMissing() + # several iterable types are not allowed and must be excluded right away - if isinstance(data, (AsyncIterator, AsyncIterable, str, dict)): - raise InvalidResourceDataType("Invalid data type for DltResource", type(data)) + if isinstance(data, (AsyncIterator, AsyncIterable)): + raise InvalidResourceAsyncDataType(name, data, type(data)) + if isinstance(data, (str, dict)): + raise InvalidResourceBasicDataType(name, data, type(data)) # check if depends_on is a valid resource parent_pipe: Pipe = None if depends_on: - if not isinstance(depends_on, DltResource): + if not callable(data): + raise DependentResourceMustBeAGeneratorFunction() + else: + pass + # TODO: check sig if takes just one argument + # if sig_valid(): + # raise DependentResourceMustTakeDataItemArgument() + if isinstance(depends_on, Pipe): + parent_pipe = depends_on + elif isinstance(depends_on, DltResource): + parent_pipe = depends_on._pipe + else: # if this is generator function provide nicer exception - if inspect.isgeneratorfunction(inspect.unwrap(depends_on)): - raise ParentResourceIsGeneratorFunction() + if callable(depends_on): + raise GeneratorFunctionNotAllowedAsParentResource(depends_on.__name__) else: raise ParentNotAResource() - parent_pipe = depends_on._pipe - # create resource from iterator or iterable + + # create resource from iterator, iterable or generator function if isinstance(data, (Iterable, Iterator)): - if inspect.isgenerator(data): - name = name or data.__name__ - else: - name = name or None - if not name: - raise ResourceNameRequired("The DltResource name was not provided or could not be inferred.") pipe = Pipe.from_iterable(name, data, parent=parent_pipe) + elif callable(data): + pipe = Pipe(name, [data], parent_pipe) + if pipe: return cls(pipe, table_schema_template, selected) - - # some other data type that is not supported - raise InvalidResourceDataType("Invalid data type for DltResource", type(data)) + else: + # some other data type that is not supported + raise InvalidResourceDataType("Invalid data type for DltResource", type(data)) def select(self, *table_names: Iterable[str]) -> "DltResource": @@ -188,14 +196,20 @@ def flat_map(self) -> None: def filter(self) -> None: raise NotImplementedError() - def __iter__(self) -> Iterator[TDirectDataItem]: + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # make resource callable to support parametrized resources which are functions taking arguments + _data = self._pipe.head(*args, **kwargs) + # create new resource from extracted data + return DltResource.from_data(_data, self.name, self._table_schema_template, self.selected, self._pipe.parent) + + def __iter__(self) -> Iterator[TDataItems]: return map(lambda 
item: item.item, PipeIterator.from_pipe(self._pipe)) def __repr__(self) -> str: return f"DltResource {self.name} ({self._pipe._pipe_id}) at {id(self)}" -class DltSource(Iterable[TDirectDataItem]): +class DltSource(Iterable[TDataItems]): def __init__(self, schema: Schema, resources: Sequence[DltResource] = None) -> None: self.name = schema.name self._schema = schema @@ -266,22 +280,8 @@ def select(self, *resource_names: str) -> "DltSource": def run(self, destination: Any) -> Any: return Container()[PipelineContext].pipeline().run(source=self, destination=destination) - def __iter__(self) -> Iterator[TDirectDataItem]: + def __iter__(self) -> Iterator[TDataItems]: return map(lambda item: item.item, PipeIterator.from_pipes(self.pipes)) def __repr__(self) -> str: return f"DltSource {self.name} at {id(self)}" - - -class DltSourceException(DltException): - pass - - -class DataItemRequiredForDynamicTableHints(DltException): - def __init__(self, resource_name: str) -> None: - self.resource_name = resource_name - super().__init__(f"Instance of Data Item required to generate table schema in resource {resource_name}") - - - -# class diff --git a/dlt/extract/typing.py b/dlt/extract/typing.py new file mode 100644 index 0000000000..7c79453eb3 --- /dev/null +++ b/dlt/extract/typing.py @@ -0,0 +1,22 @@ +from typing import Callable, TypedDict, TypeVar, Union, List, Awaitable + +from dlt.common.typing import TDataItem, TDataItems +from dlt.common.schema.typing import TTableSchemaColumns, TWriteDisposition + + +TDeferredDataItems = Callable[[], TDataItems] +TAwaitableDataItems = Awaitable[TDataItems] +TPipedDataItems = Union[TDataItems, TDeferredDataItems, TAwaitableDataItems] + +TDynHintType = TypeVar("TDynHintType") +TFunHintTemplate = Callable[[TDataItem], TDynHintType] +TTableHintTemplate = Union[TDynHintType, TFunHintTemplate[TDynHintType]] + + +class TTableSchemaTemplate(TypedDict, total=False): + name: TTableHintTemplate[str] + description: TTableHintTemplate[str] + write_disposition: TTableHintTemplate[TWriteDisposition] + # table_sealed: Optional[bool] + parent: TTableHintTemplate[str] + columns: TTableHintTemplate[TTableSchemaColumns] From 0a9691a216d9eb1697f53f0d2045e9b06b585c82 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sun, 30 Oct 2022 01:32:18 +0200 Subject: [PATCH 60/66] adds self importing via name or module name for destination reference --- dlt/common/destination.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/dlt/common/destination.py b/dlt/common/destination.py index e3cda90278..328a332450 100644 --- a/dlt/common/destination.py +++ b/dlt/common/destination.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod from importlib import import_module +from nis import cat from types import TracebackType -from typing import ClassVar, List, Optional, Literal, Type, Protocol, Union, TYPE_CHECKING +from typing import ClassVar, List, Optional, Literal, Type, Protocol, Union, TYPE_CHECKING, cast from dlt.common.schema import Schema from dlt.common.schema.typing import TTableSchema @@ -141,14 +142,18 @@ def client(self, schema: Schema, initial_config: DestinationClientConfiguration def spec(self) -> Type[DestinationClientConfiguration]: ... 
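The from_name helper added in place of resolve_destination_reference below accepts either a short destination name, which is imported from the known dlt.load location, or a fully qualified module path, which is imported as-is. A rough usage sketch (the dotted custom module name is hypothetical and assumed to be importable):

from dlt.common.destination import DestinationReference

# short name: resolved from the known location, here dlt.load.bigquery
bigquery = DestinationReference.from_name("bigquery")

# dotted name: treated as a full module path and imported as-is
# (hypothetical module, assumed to exist in the running environment)
custom = DestinationReference.from_name("my_company.destinations.custom_dwh")

# an already resolved reference (or None) is returned unchanged
assert DestinationReference.from_name(None) is None

Returning the imported module itself keeps the reference duck-typed against the DestinationReference protocol shown above.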
- -def resolve_destination_reference(destination: Union[None, str, DestinationReference]) -> DestinationReference: - if destination is None: - return None - - if isinstance(destination, str): - # TODO: figure out if this is full module path name or name of one of the known destinations - # if destination is a str, get destination reference by dynamically importing module from known location - return import_module(f"dlt.load.{destination}") - - return destination \ No newline at end of file + @staticmethod + def from_name(destination: Union[None, str, "DestinationReference"]) -> "DestinationReference": + if destination is None: + return None + + # if destination is a str, get destination reference by dynamically importing module + if isinstance(destination, str): + if "." in destination: + # this is full module name + return cast(DestinationReference, import_module(destination)) + else: + # from known location + return cast(DestinationReference, import_module(f"dlt.load.{destination}")) + + return destination From 1f7bf39d8ef61b7ce9cd348475d041cc57183dda Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sun, 30 Oct 2022 01:32:37 +0200 Subject: [PATCH 61/66] fixes config exceptions names --- dlt/common/configuration/__init__.py | 2 +- dlt/common/configuration/exceptions.py | 29 +++++++++++++++---- dlt/common/configuration/providers/toml.py | 1 - dlt/common/configuration/resolve.py | 11 ++++--- .../specs/config_providers_context.py | 14 +++++---- dlt/common/data_writers/buffered.py | 5 ++-- dlt/common/exceptions.py | 6 ++-- dlt/common/storages/data_item_storage.py | 4 +-- dlt/common/storages/file_storage.py | 2 +- dlt/helpers/streamlit.py | 1 - dlt/load/bigquery/configuration.py | 2 +- .../configuration/test_configuration.py | 22 +++++++------- .../configuration/test_environ_provider.py | 6 ++-- tests/common/configuration/test_namespaces.py | 8 ++--- .../configuration/test_toml_provider.py | 14 ++++----- 15 files changed, 72 insertions(+), 55 deletions(-) diff --git a/dlt/common/configuration/__init__.py b/dlt/common/configuration/__init__.py index ccc176c807..8a2d92346b 100644 --- a/dlt/common/configuration/__init__.py +++ b/dlt/common/configuration/__init__.py @@ -3,4 +3,4 @@ from .inject import with_config, last_config, get_fun_spec from .exceptions import ( # noqa: F401 - ConfigEntryMissingException, ConfigEnvValueCannotBeCoercedException, ConfigIntegrityException, ConfigFileNotFoundException) + ConfigFieldMissingException, ConfigValueCannotBeCoercedException, ConfigIntegrityException, ConfigFileNotFoundException) diff --git a/dlt/common/configuration/exceptions.py b/dlt/common/configuration/exceptions.py index ee70ec9c62..c905820e06 100644 --- a/dlt/common/configuration/exceptions.py +++ b/dlt/common/configuration/exceptions.py @@ -15,13 +15,24 @@ def __init__(self, msg: str) -> None: super().__init__(msg) + +class ContainerException(ConfigurationException): + """base exception for all exceptions related to injectable container""" + pass + + +class ConfigProviderException(ConfigurationException): + """base exceptions for all exceptions raised by config providers""" + pass + + class ConfigurationWrongTypeException(ConfigurationException): def __init__(self, _typ: type) -> None: super().__init__(f"Invalid configuration instance type {_typ}. 
Configuration instances must derive from BaseConfiguration.") -class ConfigEntryMissingException(ConfigurationException): - """thrown when not all required config elements are present""" +class ConfigFieldMissingException(ConfigurationException): + """thrown when not all required config fields are present""" def __init__(self, spec_name: str, traces: Mapping[str, Sequence[LookupTrace]]) -> None: self.traces = traces @@ -35,8 +46,8 @@ def __init__(self, spec_name: str, traces: Mapping[str, Sequence[LookupTrace]]) super().__init__(msg) -class ConfigEnvValueCannotBeCoercedException(ConfigurationException): - """thrown when value from ENV cannot be coerced to hinted type""" +class ConfigValueCannotBeCoercedException(ConfigurationException): + """thrown when value returned by config provider cannot be coerced to hinted type""" def __init__(self, field_name: str, field_value: Any, hint: type) -> None: self.field_name = field_name @@ -94,7 +105,7 @@ def __init__(self, spec: Type[Any], initial_value_type: Type[Any]) -> None: super().__init__(f"Initial value of type {initial_value_type} is not valid for {spec.__name__}") -class ContainerInjectableContextMangled(ConfigurationException): +class ContainerInjectableContextMangled(ContainerException): def __init__(self, spec: Type[Any], existing_config: Any, expected_config: Any) -> None: self.spec = spec self.existing_config = existing_config @@ -102,7 +113,13 @@ def __init__(self, spec: Type[Any], existing_config: Any, expected_config: Any) super().__init__(f"When restoring context {spec.__name__}, instance {expected_config} was expected, instead instance {existing_config} was found.") -class ContextDefaultCannotBeCreated(ConfigurationException, KeyError): +class ContextDefaultCannotBeCreated(ContainerException, KeyError): def __init__(self, spec: Type[Any]) -> None: self.spec = spec super().__init__(f"Container cannot create the default value of context {spec.__name__}.") + + +class DuplicateConfigProviderException(ConfigProviderException): + def __init__(self, provider_name: str) -> None: + self.provider_name = provider_name + super().__init__(f"Provider with name {provider_name} already present in ConfigProvidersContext") diff --git a/dlt/common/configuration/providers/toml.py b/dlt/common/configuration/providers/toml.py index c4916d9b3a..0f49691cbc 100644 --- a/dlt/common/configuration/providers/toml.py +++ b/dlt/common/configuration/providers/toml.py @@ -1,5 +1,4 @@ import os -import abc import tomlkit from typing import Any, Optional, Tuple, Type diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index 5be97e4e92..e3e6b8843e 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -12,7 +12,7 @@ from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext from dlt.common.configuration.providers.container import ContextProvider -from dlt.common.configuration.exceptions import (LookupTrace, ConfigEntryMissingException, ConfigurationWrongTypeException, ConfigEnvValueCannotBeCoercedException, ValueNotSecretException, InvalidInitialValue) +from dlt.common.configuration.exceptions import (LookupTrace, ConfigFieldMissingException, ConfigurationWrongTypeException, ConfigValueCannotBeCoercedException, ValueNotSecretException, InvalidInitialValue) CHECK_INTEGRITY_F: str = "check_integrity" TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration) @@ -47,10 +47,10 @@ def 
deserialize_value(key: str, value: Any, hint: Type[Any]) -> Any: if value_dt != hint_dt: value = coerce_type(hint_dt, value_dt, value) return value - except ConfigEnvValueCannotBeCoercedException: + except ConfigValueCannotBeCoercedException: raise except Exception as exc: - raise ConfigEnvValueCannotBeCoercedException(key, value, hint) from exc + raise ConfigValueCannotBeCoercedException(key, value, hint) from exc def serialize_value(value: Any) -> Any: @@ -91,7 +91,6 @@ def _resolve_configuration( initial_value: Any, accept_partial: bool ) -> TConfiguration: - # print(f"RESOLVING: {locals()}") # do not resolve twice if config.is_resolved(): return config @@ -134,7 +133,7 @@ def _resolve_configuration( _check_configuration_integrity(config) # full configuration was resolved config.__is_resolved__ = True - except ConfigEntryMissingException as cm_ex: + except ConfigFieldMissingException as cm_ex: if not accept_partial: raise else: @@ -211,7 +210,7 @@ def _resolve_config_fields( # set resolved value in config setattr(config, key, current_value) if unresolved_fields: - raise ConfigEntryMissingException(type(config).__name__, unresolved_fields) + raise ConfigFieldMissingException(type(config).__name__, unresolved_fields) def _log_traces(config: BaseConfiguration, key: str, hint: Type[Any], value: Any, traces: Sequence[LookupTrace]) -> None: diff --git a/dlt/common/configuration/specs/config_providers_context.py b/dlt/common/configuration/specs/config_providers_context.py index f6576b8367..02cd4d6d37 100644 --- a/dlt/common/configuration/specs/config_providers_context.py +++ b/dlt/common/configuration/specs/config_providers_context.py @@ -2,11 +2,12 @@ from typing import List +from dlt.common.configuration.exceptions import DuplicateConfigProviderException from dlt.common.configuration.providers import Provider from dlt.common.configuration.providers.environ import EnvironProvider from dlt.common.configuration.providers.container import ContextProvider from dlt.common.configuration.providers.toml import SecretsTomlProvider, ConfigTomlProvider -from dlt.common.configuration.specs.base_configuration import BaseConfiguration, ContainerInjectableContext, configspec +from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext, configspec @configspec @@ -34,11 +35,12 @@ def __contains__(self, name: object) -> bool: def add_provider(self, provider: Provider) -> None: if provider.name in self: - raise DuplicateProviderException(provider.name) + raise DuplicateConfigProviderException(provider.name) self.providers.append(provider) -@configspec -class ConfigProvidersConfiguration(BaseConfiguration): - with_aws_secrets: bool = False - with_google_secrets: bool = False +# TODO: implement ConfigProvidersConfiguration and +# @configspec +# class ConfigProvidersConfiguration(BaseConfiguration): +# with_aws_secrets: bool = False +# with_google_secrets: bool = False diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py index bbb6a380b2..1423fa3abe 100644 --- a/dlt/common/data_writers/buffered.py +++ b/dlt/common/data_writers/buffered.py @@ -1,8 +1,7 @@ from typing import List, IO, Any, Optional from dlt.common.utils import uniq_id -from dlt.common.typing import TDataItem -from dlt.common.source import TDirectDataItem +from dlt.common.typing import TDataItem, TDataItems from dlt.common.data_writers import TLoaderFileFormat from dlt.common.data_writers.exceptions import BufferedDataWriterClosed, InvalidFileNameTemplateException from 
dlt.common.data_writers.writers import DataWriter @@ -35,7 +34,7 @@ def __init__(self, file_format: TLoaderFileFormat, file_name_template: str, *, b except TypeError: raise InvalidFileNameTemplateException(file_name_template) - def write_data_item(self, item: TDirectDataItem, columns: TTableSchemaColumns) -> None: + def write_data_item(self, item: TDataItems, columns: TTableSchemaColumns) -> None: self._ensure_open() # rotate file if columns changed and writer does not allow for that # as the only allowed change is to add new column (no updates/deletes), we detect the change by comparing lengths diff --git a/dlt/common/exceptions.py b/dlt/common/exceptions.py index 21e44450d0..0d88d8d5f4 100644 --- a/dlt/common/exceptions.py +++ b/dlt/common/exceptions.py @@ -84,6 +84,8 @@ def __init__(self, msg: str, path: str, field: str = None, value: Any = None) -> class ArgumentsOverloadException(DltException): - def __init__(self, msg: str, *args: str) -> None: - self.args = args + def __init__(self, msg: str, func_name: str, *args: str) -> None: + self.func_name = func_name + msg = f"Arguments combination not allowed when calling function {func_name}: {msg}" + msg = "\n".join((msg, *args)) super().__init__(msg) diff --git a/dlt/common/storages/data_item_storage.py b/dlt/common/storages/data_item_storage.py index d3304d1418..27a4a688b1 100644 --- a/dlt/common/storages/data_item_storage.py +++ b/dlt/common/storages/data_item_storage.py @@ -3,7 +3,7 @@ from dlt.common import logger from dlt.common.schema import TTableSchemaColumns -from dlt.common.source import TDirectDataItem +from dlt.common.typing import TDataItems from dlt.common.data_writers import TLoaderFileFormat, BufferedDataWriter @@ -13,7 +13,7 @@ def __init__(self, load_file_type: TLoaderFileFormat, *args: Any) -> None: self.buffered_writers: Dict[str, BufferedDataWriter] = {} super().__init__(*args) - def write_data_item(self, load_id: str, schema_name: str, table_name: str, item: TDirectDataItem, columns: TTableSchemaColumns) -> None: + def write_data_item(self, load_id: str, schema_name: str, table_name: str, item: TDataItems, columns: TTableSchemaColumns) -> None: # unique writer id writer_id = f"{load_id}.{schema_name}.{table_name}" writer = self.buffered_writers.get(writer_id, None) diff --git a/dlt/common/storages/file_storage.py b/dlt/common/storages/file_storage.py index 400d162426..8947515a42 100644 --- a/dlt/common/storages/file_storage.py +++ b/dlt/common/storages/file_storage.py @@ -129,7 +129,7 @@ def to_relative_path(self, path: str) -> str: return os.path.relpath(path, start=self.storage_path) def make_full_path(self, path: str) -> str: - # try to make a relative path is paths are absolute or overlapping + # try to make a relative path if paths are absolute or overlapping try: path = self.to_relative_path(path) except ValueError: diff --git a/dlt/helpers/streamlit.py b/dlt/helpers/streamlit.py index a65d0b2f0d..f063ff2273 100644 --- a/dlt/helpers/streamlit.py +++ b/dlt/helpers/streamlit.py @@ -1,4 +1,3 @@ - import os import tomlkit from tomlkit.container import Container as TomlContainer diff --git a/dlt/load/bigquery/configuration.py b/dlt/load/bigquery/configuration.py index 325ff6430f..865dd0db35 100644 --- a/dlt/load/bigquery/configuration.py +++ b/dlt/load/bigquery/configuration.py @@ -4,7 +4,7 @@ from dlt.common.configuration import configspec from dlt.common.configuration.specs import GcpClientCredentials -from dlt.common.configuration.exceptions import ConfigEntryMissingException +from 
dlt.common.configuration.exceptions import ConfigFieldMissingException from dlt.common.destination import DestinationClientDwhConfiguration diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index a442e49366..a9fc2a7e9e 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -6,7 +6,7 @@ from dlt.common.utils import custom_environ from dlt.common.typing import TSecretValue, extract_inner_type from dlt.common.configuration.exceptions import ConfigFieldMissingTypeHintException, ConfigFieldTypeHintNotSupported, InvalidInitialValue, LookupTrace, ValueNotSecretException -from dlt.common.configuration import configspec, ConfigEntryMissingException, ConfigEnvValueCannotBeCoercedException, resolve +from dlt.common.configuration import configspec, ConfigFieldMissingException, ConfigValueCannotBeCoercedException, resolve from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration from dlt.common.configuration.specs.base_configuration import is_valid_hint from dlt.common.configuration.providers import environ as environ_provider, toml @@ -331,7 +331,7 @@ class MultiConfiguration(MockProdConfiguration, ConfigurationWithOptionalTypes, def test_raises_on_unresolved_field(environment: Any) -> None: # via make configuration - with pytest.raises(ConfigEntryMissingException) as cf_missing_exc: + with pytest.raises(ConfigFieldMissingException) as cf_missing_exc: resolve.resolve_configuration(WrongConfiguration()) assert cf_missing_exc.value.spec_name == "WrongConfiguration" assert "NoneConfigVar" in cf_missing_exc.value.traces @@ -345,7 +345,7 @@ def test_raises_on_unresolved_field(environment: Any) -> None: def test_raises_on_many_unresolved_fields(environment: Any) -> None: # via make configuration - with pytest.raises(ConfigEntryMissingException) as cf_missing_exc: + with pytest.raises(ConfigFieldMissingException) as cf_missing_exc: resolve.resolve_configuration(CoercionTestConfiguration()) assert cf_missing_exc.value.spec_name == "CoercionTestConfiguration" # get all fields that must be set @@ -425,7 +425,7 @@ def test_invalid_coercions(environment: Any) -> None: for key, value in INVALID_COERCIONS.items(): try: resolve._resolve_config_fields(C, explicit_namespaces=(), embedded_namespaces=(), accept_partial=False) - except ConfigEnvValueCannotBeCoercedException as coerc_exc: + except ConfigValueCannotBeCoercedException as coerc_exc: # must fail exactly on expected value if coerc_exc.field_name != key: raise @@ -511,7 +511,7 @@ def test_accept_partial(environment: Any) -> None: def test_coercion_rules() -> None: - with pytest.raises(ConfigEnvValueCannotBeCoercedException): + with pytest.raises(ConfigValueCannotBeCoercedException): coerce_single_value("key", "some string", int) assert coerce_single_value("key", "some string", str) == "some string" # Optional[str] has type object, mypy will never work properly... 
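The renamed ConfigValueCannotBeCoercedException reflects that values may come from any config provider, not only environment variables, yet still arrive as strings that must be coerced to the hinted field type. A small self-contained sketch of that idea, not dlt's actual resolver code (names here are illustrative):

from typing import Any, Type

class ValueCannotBeCoerced(ValueError):
    def __init__(self, key: str, value: Any, hint: Type[Any]) -> None:
        super().__init__(f"Configured value {value!r} for {key} cannot be coerced to {hint.__name__}")

def coerce(key: str, raw: str, hint: Type[Any]) -> Any:
    # provider values arrive as strings; convert them to the hinted type
    try:
        if hint is bool:
            return raw.strip().lower() in ("1", "true", "yes")
        return hint(raw)
    except (TypeError, ValueError) as exc:
        raise ValueCannotBeCoerced(key, raw, hint) from exc

assert coerce("workers", "20", int) == 20
assert coerce("always_wipe_storage", "false", bool) is False

In dlt itself this path runs through deserialize_value and coerce_type, which wrap coercion failures in the exception renamed in the resolve.py hunk earlier in this patch.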
@@ -528,9 +528,9 @@ def test_coercion_rules() -> None: assert coerce_single_value("key", "234", LongInteger) == 234 assert coerce_single_value("key", "234", Optional[LongInteger]) == 234 # type: ignore # this coercion should fail - with pytest.raises(ConfigEnvValueCannotBeCoercedException): + with pytest.raises(ConfigValueCannotBeCoercedException): coerce_single_value("key", "some string", LongInteger) - with pytest.raises(ConfigEnvValueCannotBeCoercedException): + with pytest.raises(ConfigValueCannotBeCoercedException): coerce_single_value("key", "some string", Optional[LongInteger]) # type: ignore @@ -544,10 +544,10 @@ def test_is_valid_hint() -> None: # in case of generics, origin will be used and args are not checked assert is_valid_hint(MutableMapping[TSecretValue, Any]) is True # this is valid (args not checked) - assert is_valid_hint(MutableMapping[TSecretValue, ConfigEnvValueCannotBeCoercedException]) is True + assert is_valid_hint(MutableMapping[TSecretValue, ConfigValueCannotBeCoercedException]) is True assert is_valid_hint(Wei) is True # any class type, except deriving from BaseConfiguration is wrong type - assert is_valid_hint(ConfigEntryMissingException) is False + assert is_valid_hint(ConfigFieldMissingException) is False def test_configspec_auto_base_config_derivation() -> None: @@ -622,10 +622,10 @@ def test_do_not_resolve_embedded(environment: Any) -> None: def test_last_resolve_exception(environment: Any) -> None: # partial will set the ConfigEntryMissingException c = resolve.resolve_configuration(EmbeddedConfiguration(), accept_partial=True) - assert isinstance(c.__exception__, ConfigEntryMissingException) + assert isinstance(c.__exception__, ConfigFieldMissingException) # missing keys c = SecretConfiguration() - with pytest.raises(ConfigEntryMissingException) as py_ex: + with pytest.raises(ConfigFieldMissingException) as py_ex: resolve.resolve_configuration(c) assert c.__exception__ is py_ex.value # but if ran again exception is cleared diff --git a/tests/common/configuration/test_environ_provider.py b/tests/common/configuration/test_environ_provider.py index a7392754d1..9f7a2c2885 100644 --- a/tests/common/configuration/test_environ_provider.py +++ b/tests/common/configuration/test_environ_provider.py @@ -2,7 +2,7 @@ from typing import Any from dlt.common.typing import TSecretValue -from dlt.common.configuration import configspec, ConfigEntryMissingException, ConfigFileNotFoundException, resolve +from dlt.common.configuration import configspec, ConfigFieldMissingException, ConfigFileNotFoundException, resolve from dlt.common.configuration.specs import RunConfiguration from dlt.common.configuration.providers import environ as environ_provider @@ -50,7 +50,7 @@ def test_resolves_from_environ_with_coercion(environment: Any) -> None: def test_secret(environment: Any) -> None: - with pytest.raises(ConfigEntryMissingException): + with pytest.raises(ConfigFieldMissingException): resolve.resolve_configuration(SecretConfiguration()) environment['SECRET_VALUE'] = "1" C = resolve.resolve_configuration(SecretConfiguration()) @@ -68,7 +68,7 @@ def test_secret(environment: Any) -> None: # set some weird path, no secret file at all del environment['SECRET_VALUE'] environ_provider.SECRET_STORAGE_PATH = "!C:\\PATH%s" - with pytest.raises(ConfigEntryMissingException): + with pytest.raises(ConfigFieldMissingException): resolve.resolve_configuration(SecretConfiguration()) # set env which is a fallback for secret not as file diff --git a/tests/common/configuration/test_namespaces.py 
b/tests/common/configuration/test_namespaces.py index 24d28568e2..2e39721817 100644 --- a/tests/common/configuration/test_namespaces.py +++ b/tests/common/configuration/test_namespaces.py @@ -2,7 +2,7 @@ from typing import Any, Optional from dlt.common.configuration.container import Container -from dlt.common.configuration import configspec, ConfigEntryMissingException, resolve, inject_namespace +from dlt.common.configuration import configspec, ConfigFieldMissingException, resolve, inject_namespace from dlt.common.configuration.specs import BaseConfiguration, ConfigNamespacesContext # from dlt.common.configuration.providers import environ as environ_provider from dlt.common.configuration.exceptions import LookupTrace @@ -27,7 +27,7 @@ class EmbeddedWithNamespacedConfiguration(BaseConfiguration): def test_namespaced_configuration(environment: Any) -> None: - with pytest.raises(ConfigEntryMissingException) as exc_val: + with pytest.raises(ConfigFieldMissingException) as exc_val: resolve.resolve_configuration(NamespacedConfiguration()) assert list(exc_val.value.traces.keys()) == ["password"] @@ -46,7 +46,7 @@ def test_namespaced_configuration(environment: Any) -> None: # env var must be prefixed environment["PASSWORD"] = "PASS" - with pytest.raises(ConfigEntryMissingException) as exc_val: + with pytest.raises(ConfigFieldMissingException) as exc_val: resolve.resolve_configuration(NamespacedConfiguration()) environment["DLT_TEST__PASSWORD"] = "PASS" C = resolve.resolve_configuration(NamespacedConfiguration()) @@ -177,7 +177,7 @@ def test_namespace_with_pipeline_name(mock_provider: MockProvider) -> None: mock_provider.return_value_on = () mock_provider.reset_stats() # () will never be searched - with pytest.raises(ConfigEntryMissingException): + with pytest.raises(ConfigFieldMissingException): resolve.resolve_configuration(NamespacedConfiguration()) mock_provider.return_value_on = ("DLT_TEST",) mock_provider.reset_stats() diff --git a/tests/common/configuration/test_toml_provider.py b/tests/common/configuration/test_toml_provider.py index 6a8128154c..304c3b59d1 100644 --- a/tests/common/configuration/test_toml_provider.py +++ b/tests/common/configuration/test_toml_provider.py @@ -4,7 +4,7 @@ from dlt.common import pendulum -from dlt.common.configuration import configspec, ConfigEntryMissingException, ConfigFileNotFoundException, resolve +from dlt.common.configuration import configspec, ConfigFieldMissingException, ConfigFileNotFoundException, resolve from dlt.common.configuration.container import Container from dlt.common.configuration.inject import with_config from dlt.common.configuration.exceptions import LookupTrace @@ -40,7 +40,7 @@ def providers() -> Iterator[ConfigProvidersContext]: def test_secrets_from_toml_secrets() -> None: - with pytest.raises(ConfigEntryMissingException) as py_ex: + with pytest.raises(ConfigFieldMissingException) as py_ex: resolve.resolve_configuration(SecretConfiguration()) # only two traces because TSecretValue won't be checked in config.toml provider @@ -49,7 +49,7 @@ def test_secrets_from_toml_secrets() -> None: assert traces[0] == LookupTrace("Environment Variables", [], "SECRET_VALUE", None) assert traces[1] == LookupTrace("Pipeline secrets.toml", [], "secret_value", None) - with pytest.raises(ConfigEntryMissingException) as py_ex: + with pytest.raises(ConfigFieldMissingException) as py_ex: resolve.resolve_configuration(WithCredentialsConfiguration()) @@ -116,13 +116,13 @@ def test_secrets_toml_credentials(providers: ConfigProvidersContext) -> None: c = 
resolve.resolve_configuration(GcpClientCredentials(), namespaces=("destination",)) assert c.project_id.endswith("destination.credentials") # there's "credentials" key but does not contain valid gcp credentials - with pytest.raises(ConfigEntryMissingException): + with pytest.raises(ConfigFieldMissingException): resolve.resolve_configuration(GcpClientCredentials()) # also try postgres credentials c = resolve.resolve_configuration(PostgresCredentials(), namespaces=("destination", "redshift")) assert c.dbname == "destination.redshift.credentials" # bigquery credentials do not match redshift credentials - with pytest.raises(ConfigEntryMissingException): + with pytest.raises(ConfigFieldMissingException): resolve.resolve_configuration(PostgresCredentials(), namespaces=("destination", "bigquery")) @@ -138,7 +138,7 @@ def test_secrets_toml_embedded_credentials(providers: ConfigProvidersContext) -> c = EmbeddedWithGcpCredentials() # create embedded config that will be passed as initial c.credentials = GcpClientCredentials() - with pytest.raises(ConfigEntryMissingException) as py_ex: + with pytest.raises(ConfigFieldMissingException) as py_ex: resolve.resolve_configuration(c, namespaces=("middleware", "storage")) # so we can read partially filled configuration here assert c.credentials.project_id.endswith("-credentials") @@ -152,7 +152,7 @@ def test_secrets_toml_embedded_credentials(providers: ConfigProvidersContext) -> c = resolve.resolve_configuration(GcpClientCredentials(), namespaces=("destination",)) assert c.project_id.endswith("destination.credentials") # there's "credentials" key but does not contain valid gcp credentials - with pytest.raises(ConfigEntryMissingException): + with pytest.raises(ConfigFieldMissingException): resolve.resolve_configuration(GcpClientCredentials()) From 371a0581b9fdc555061d5658dadb81954406568d Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Mon, 31 Oct 2022 20:22:42 +0100 Subject: [PATCH 62/66] adds optional embedded namespaces extension in config resolve + tests --- dlt/common/configuration/resolve.py | 47 ++++++++++--------- .../configuration/specs/base_configuration.py | 10 ++-- dlt/load/configuration.py | 2 +- dlt/normalize/configuration.py | 6 +-- dlt/pipeline/configuration.py | 4 +- .../configuration/test_configuration.py | 47 +++++++++++-------- tests/common/configuration/test_namespaces.py | 44 +++++++++++++++-- 7 files changed, 104 insertions(+), 56 deletions(-) diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index e3e6b8843e..0e9348fc0e 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -108,14 +108,15 @@ def _resolve_configuration( resolved_initial: Any = None if config.__namespace__ or embedded_namespaces: cf_n, emb_ns = _apply_embedded_namespaces_to_config_namespace(config.__namespace__, embedded_namespaces) - resolved_initial, traces = _resolve_single_field(cf_n, type(config), None, explicit_namespaces, emb_ns) - _log_traces(config, cf_n, type(config), resolved_initial, traces) - # initial values cannot be dictionaries - if not isinstance(resolved_initial, C_Mapping): - initial_value = resolved_initial or initial_value - # if this is injectable context then return it immediately - if isinstance(resolved_initial, ContainerInjectableContext): - return resolved_initial # type: ignore + if cf_n: + resolved_initial, traces = _resolve_single_field(cf_n, type(config), None, explicit_namespaces, emb_ns) + _log_traces(config, cf_n, type(config), resolved_initial, traces) + # 
initial values cannot be dictionaries + if not isinstance(resolved_initial, C_Mapping): + initial_value = resolved_initial or initial_value + # if this is injectable context then return it immediately + if isinstance(resolved_initial, ContainerInjectableContext): + return resolved_initial # type: ignore try: try: # use initial value to set config values @@ -147,20 +148,6 @@ def _resolve_configuration( return config -def _apply_embedded_namespaces_to_config_namespace(config_namespace: str, embedded_namespaces: Tuple[str, ...]) -> Tuple[str, Tuple[str, ...]]: - # for the configurations that have __namespace__ (config_namespace) defined and are embedded in other configurations, - # the innermost embedded namespace replaces config_namespace - if embedded_namespaces: - config_namespace = embedded_namespaces[-1] - embedded_namespaces = embedded_namespaces[:-1] - # if config_namespace: - return config_namespace, embedded_namespaces - - -def _is_secret_hint(hint: Type[Any]) -> bool: - return hint is TSecretValue or (inspect.isclass(hint) and issubclass(hint, CredentialsConfiguration)) - - def _resolve_config_fields( config: BaseConfiguration, explicit_namespaces: Tuple[str, ...], @@ -311,3 +298,19 @@ def look_namespaces(pipeline_name: str = None) -> Any: value = look_namespaces() return value, traces + + +def _apply_embedded_namespaces_to_config_namespace(config_namespace: str, embedded_namespaces: Tuple[str, ...]) -> Tuple[str, Tuple[str, ...]]: + # for the configurations that have __namespace__ (config_namespace) defined and are embedded in other configurations, + # the innermost embedded namespace replaces config_namespace + if embedded_namespaces: + # do not add key to embedded namespaces if it starts with _, those namespaces must be ignored + if not embedded_namespaces[-1].startswith("_"): + config_namespace = embedded_namespaces[-1] + embedded_namespaces = embedded_namespaces[:-1] + + return config_namespace, embedded_namespaces + + +def _is_secret_hint(hint: Type[Any]) -> bool: + return hint is TSecretValue or (inspect.isclass(hint) and issubclass(hint, CredentialsConfiguration)) diff --git a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py index 36d7c474d7..9e98958778 100644 --- a/dlt/common/configuration/specs/base_configuration.py +++ b/dlt/common/configuration/specs/base_configuration.py @@ -81,9 +81,6 @@ class BaseConfiguration(MutableMapping[str, Any]): # holds the exception that prevented the full resolution __exception__: Exception = dataclasses.field(default = None, init=False, repr=False) - def __init__(self) -> None: - self.__ignore_set_unknown_keys = False - def from_native_representation(self, native_value: Any) -> None: """Initialize the configuration fields by parsing the `initial_value` which should be a native representation of the configuration or credentials, for example database connection string or JSON serialized GCP service credentials file. 
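With the underscore rule above in place, an embedded field participates in namespace lookups only when its name does not start with "_", which is why later hunks rename fields such as load_storage_config and runtime to _load_storage_config and _runtime. A brief sketch of the declaration-side difference (class and field names here are hypothetical; the NAMESPACE__FIELD key shape follows the environment provider convention used in the tests):

from dlt.common.configuration import configspec
from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration

@configspec
class JobConfiguration(BaseConfiguration):
    # a normal embedded field: its name becomes an embedded namespace, so
    # runtime.pipeline_name may be looked up under RUNTIME__PIPELINE_NAME
    runtime: RunConfiguration

@configspec
class LeanJobConfiguration(BaseConfiguration):
    # the leading underscore keeps the field name out of the embedded
    # namespaces, so pipeline_name resolves from its plain key instead
    _runtime: RunConfiguration

This keeps keys for shared sub-configurations, like the storage configs, flat instead of nesting them under each owning field name.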
@@ -136,7 +133,12 @@ def __setitem__(self, __key: str, __value: Any) -> None: if self.__has_attr(__key): setattr(self, __key, __value) else: - if not self.__ignore_set_unknown_keys: + try: + if not self.__ignore_set_unknown_keys: + # assert getattr(self, "__ignore_set_unknown_keys") is not None + raise KeyError(__key) + except AttributeError: + # __ignore_set_unknown_keys attribute may not be present at the moment of checking, __init__ of BaseConfiguration is not typically called raise KeyError(__key) def __delitem__(self, __key: str) -> None: diff --git a/dlt/load/configuration.py b/dlt/load/configuration.py index bc13eb15ab..42d8cb1209 100644 --- a/dlt/load/configuration.py +++ b/dlt/load/configuration.py @@ -10,7 +10,7 @@ class LoaderConfiguration(PoolRunnerConfiguration): workers: int = 20 # how many parallel loads can be executed pool_type: TPoolType = "thread" # mostly i/o (upload) so may be thread pool always_wipe_storage: bool = False # removes all data in the storage - load_storage_config: LoadVolumeConfiguration = None + _load_storage_config: LoadVolumeConfiguration = None if TYPE_CHECKING: def __init__( diff --git a/dlt/normalize/configuration.py b/dlt/normalize/configuration.py index 16f85dc880..6520a25c57 100644 --- a/dlt/normalize/configuration.py +++ b/dlt/normalize/configuration.py @@ -9,9 +9,9 @@ class NormalizeConfiguration(PoolRunnerConfiguration): pool_type: TPoolType = "process" destination_capabilities: DestinationCapabilitiesContext = None # injectable - schema_storage_config: SchemaVolumeConfiguration - normalize_storage_config: NormalizeVolumeConfiguration - load_storage_config: LoadVolumeConfiguration + _schema_storage_config: SchemaVolumeConfiguration + _normalize_storage_config: NormalizeVolumeConfiguration + _load_storage_config: LoadVolumeConfiguration if TYPE_CHECKING: def __init__( diff --git a/dlt/pipeline/configuration.py b/dlt/pipeline/configuration.py index 27f63f4f7d..ccf78bf243 100644 --- a/dlt/pipeline/configuration.py +++ b/dlt/pipeline/configuration.py @@ -14,13 +14,13 @@ class PipelineConfiguration(BaseConfiguration): pipeline_name: Optional[str] = None working_dir: Optional[str] = None pipeline_secret: Optional[TSecretValue] = None - runtime: RunConfiguration + _runtime: RunConfiguration def check_integrity(self) -> None: if not self.pipeline_secret: self.pipeline_secret = TSecretValue(uniq_id()) if not self.pipeline_name: - self.pipeline_name = self.runtime.pipeline_name + self.pipeline_name = self._runtime.pipeline_name @configspec(init=True) diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index a9fc2a7e9e..711bf4ee83 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -264,41 +264,48 @@ class _SecretCredentials(RunConfiguration): assert dict(_SecretCredentials()) == expected_dict environment["SECRET_VALUE"] = "secret" - C = resolve.resolve_configuration(_SecretCredentials()) + c = resolve.resolve_configuration(_SecretCredentials()) expected_dict["secret_value"] = "secret" - assert dict(C) == expected_dict + assert dict(c) == expected_dict # check mutable mapping type - assert isinstance(C, MutableMapping) - assert isinstance(C, Mapping) - assert not isinstance(C, Dict) + assert isinstance(c, MutableMapping) + assert isinstance(c, Mapping) + assert not isinstance(c, Dict) # check view ops - assert C.keys() == expected_dict.keys() - assert len(C) == len(expected_dict) - assert C.items() == expected_dict.items() - 
assert list(C.values()) == list(expected_dict.values()) - for key in C: - assert C[key] == expected_dict[key] + assert c.keys() == expected_dict.keys() + assert len(c) == len(expected_dict) + assert c.items() == expected_dict.items() + assert list(c.values()) == list(expected_dict.values()) + for key in c: + assert c[key] == expected_dict[key] # version is present as attr but not present in dict - assert hasattr(C, "__is_resolved__") - assert hasattr(C, "__namespace__") + assert hasattr(c, "__is_resolved__") + assert hasattr(c, "__namespace__") # set ops # update supported and non existing attributes are ignored - C.update({"pipeline_name": "old pipe", "__version": "1.1.1"}) - assert C.pipeline_name == "old pipe" == C["pipeline_name"] + c.update({"pipeline_name": "old pipe", "__version": "1.1.1"}) + assert c.pipeline_name == "old pipe" == c["pipeline_name"] # delete is not supported with pytest.raises(KeyError): - del C["pipeline_name"] + del c["pipeline_name"] with pytest.raises(KeyError): - C.pop("pipeline_name", None) + c.pop("pipeline_name", None) # setting supported - C["pipeline_name"] = "new pipe" - assert C.pipeline_name == "new pipe" == C["pipeline_name"] + c["pipeline_name"] = "new pipe" + assert c.pipeline_name == "new pipe" == c["pipeline_name"] + with pytest.raises(KeyError): + c["unknown_prop"] = "unk" + + # also on new instance + c = SecretConfiguration() + with pytest.raises(KeyError): + c["unknown_prop"] = "unk" def test_fields_with_no_default_to_null(environment: Any) -> None: @@ -602,7 +609,7 @@ def test_do_not_resolve_twice(environment: Any) -> None: assert c2 is c3 is c4 # also c is resolved so c.secret_value = "else" - resolve.resolve_configuration(c).secret_value == "else" + assert resolve.resolve_configuration(c).secret_value == "else" def test_do_not_resolve_embedded(environment: Any) -> None: diff --git a/tests/common/configuration/test_namespaces.py b/tests/common/configuration/test_namespaces.py index 2e39721817..85369c8fdb 100644 --- a/tests/common/configuration/test_namespaces.py +++ b/tests/common/configuration/test_namespaces.py @@ -4,7 +4,6 @@ from dlt.common.configuration import configspec, ConfigFieldMissingException, resolve, inject_namespace from dlt.common.configuration.specs import BaseConfiguration, ConfigNamespacesContext -# from dlt.common.configuration.providers import environ as environ_provider from dlt.common.configuration.exceptions import LookupTrace from tests.utils import preserve_environ @@ -26,6 +25,22 @@ class EmbeddedWithNamespacedConfiguration(BaseConfiguration): embedded: NamespacedConfiguration +@configspec +class EmbeddedIgnoredConfiguration(BaseConfiguration): + # underscore prevents the field name to be added to embedded namespaces + _sv_config: Optional[SingleValConfiguration] + + +@configspec +class EmbeddedIgnoredWithNamespacedConfiguration(BaseConfiguration): + _embedded: NamespacedConfiguration + + +@configspec +class EmbeddedWithIgnoredEmbeddedConfiguration(BaseConfiguration): + ignored_embedded: EmbeddedIgnoredWithNamespacedConfiguration + + def test_namespaced_configuration(environment: Any) -> None: with pytest.raises(ConfigFieldMissingException) as exc_val: resolve.resolve_configuration(NamespacedConfiguration()) @@ -108,17 +123,38 @@ def test_overwrite_config_namespace_from_embedded(mock_provider: MockProvider) - def test_explicit_namespaces_from_embedded_config(mock_provider: MockProvider) -> None: mock_provider.value = {"sv": "A"} mock_provider.return_value_on = ("sv_config",) - C = 
resolve.resolve_configuration(EmbeddedConfiguration()) + c = resolve.resolve_configuration(EmbeddedConfiguration()) # we mock the dictionary below as the value for all requests - assert C.sv_config.sv == '{"sv": "A"}' + assert c.sv_config.sv == '{"sv": "A"}' # following namespaces were used when resolving EmbeddedConfig: () trying to get initial value for the whole embedded sv_config, then ("sv_config",), () to resolve sv in sv_config assert mock_provider.last_namespaces == [(), ("sv_config",)] # embedded namespace inner of explicit mock_provider.reset_stats() - C = resolve.resolve_configuration(EmbeddedConfiguration(), namespaces=("ns1",)) + resolve.resolve_configuration(EmbeddedConfiguration(), namespaces=("ns1",)) assert mock_provider.last_namespaces == [("ns1",), (), ("ns1", "sv_config",), ("sv_config",)] +def test_ignore_embedded_namespace_by_field_name(mock_provider: MockProvider) -> None: + mock_provider.value = {"sv": "A"} + resolve.resolve_configuration(EmbeddedIgnoredConfiguration()) + # _sv_config will not be added to embedded namespaces and looked up + assert mock_provider.last_namespaces == [()] + mock_provider.reset_stats() + resolve.resolve_configuration(EmbeddedIgnoredConfiguration(), namespaces=("ns1",)) + assert mock_provider.last_namespaces == [("ns1",), ()] + # if namespace config exist, it won't be replaced by embedded namespace + mock_provider.reset_stats() + mock_provider.value = {} + mock_provider.return_value_on = ("DLT_TEST",) + resolve.resolve_configuration(EmbeddedIgnoredWithNamespacedConfiguration()) + assert mock_provider.last_namespaces == [(), ("DLT_TEST",)] + # embedded configuration of depth 2: first normal, second - ignored + mock_provider.reset_stats() + mock_provider.return_value_on = ("DLT_TEST",) + resolve.resolve_configuration(EmbeddedWithIgnoredEmbeddedConfiguration()) + assert mock_provider.last_namespaces == [(), ('ignored_embedded',), ('ignored_embedded', 'DLT_TEST'), ('DLT_TEST',)] + + def test_injected_namespaces(mock_provider: MockProvider) -> None: container = Container() mock_provider.value = "value" From bc614666b438e6b5808599ad71369582ea42e5f5 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Mon, 31 Oct 2022 20:25:10 +0100 Subject: [PATCH 63/66] refactors interface to select resources in source, adds missing exceptions and typings --- dlt/extract/decorators.py | 30 ++++----- dlt/extract/exceptions.py | 39 +++++++++-- dlt/extract/extract.py | 7 +- dlt/extract/pipe.py | 2 +- dlt/extract/source.py | 132 +++++++++++++++++++++++++------------- dlt/extract/typing.py | 2 +- 6 files changed, 144 insertions(+), 68 deletions(-) diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index 2316cd4b29..4f17525cdf 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -10,6 +10,7 @@ from dlt.common.schema.typing import TTableSchemaColumns, TWriteDisposition from dlt.common.typing import AnyFun, ParamSpec, TDataItems from dlt.common.utils import is_inner_function +from dlt.extract.exceptions import InvalidResourceDataTypeFunctionNotAGenerator from dlt.extract.typing import TTableHintTemplate, TFunHintTemplate from dlt.extract.source import DltResource, DltSource @@ -60,20 +61,21 @@ def decorator(f: Callable[TSourceFunParams, Any]) -> Callable[TSourceFunParams, @wraps(conf_f, func_name=name) def _wrap(*args: Any, **kwargs: Any) -> DltSource: rv = conf_f(*args, **kwargs) + # if generator, consume it immediately if inspect.isgenerator(rv): rv = list(rv) - def check_rv_type(rv: Any) -> None: - pass + # def check_rv_type(rv: 
Any) -> None: + # pass - # check if return type is list or tuple - if isinstance(rv, (list, tuple)): - # check all returned elements - for v in rv: - check_rv_type(v) - else: - check_rv_type(rv) + # # check if return type is list or tuple + # if isinstance(rv, (list, tuple)): + # # check all returned elements + # for v in rv: + # check_rv_type(v) + # else: + # check_rv_type(rv) # convert to source return DltSource.from_data(schema, rv) @@ -91,7 +93,7 @@ def check_rv_type(rv: Any) -> None: return decorator if not callable(func): - raise ValueError("First parameter to the source must be callable ie. by using it as function decorator") + raise ValueError("First parameter to the source must be a callable.") # we're called as @source without parens. return decorator(func) @@ -192,15 +194,14 @@ def decorator(f: Callable[TResourceFunParams, Any]) -> Callable[TResourceFunPara resource_name = name or f.__name__ # if f is not a generator (does not yield) raise Exception - # if not inspect.isgeneratorfunction(inspect.unwrap(f)): - # raise ResourceFunNotGenerator() + if not inspect.isgeneratorfunction(inspect.unwrap(f)): + raise InvalidResourceDataTypeFunctionNotAGenerator(resource_name, f, type(f)) # do not inject config values for inner functions, we assume that they are part of the source SPEC: Type[BaseConfiguration] = None if is_inner_function(f): conf_f = f else: - print("USE SPEC -> GLOBAL") # wrap source extraction function in configuration with namespace conf_f = with_config(f, spec=spec, namespaces=("resource", resource_name)) # get spec for wrapped function @@ -215,8 +216,7 @@ def decorator(f: Callable[TResourceFunParams, Any]) -> Callable[TResourceFunPara _SOURCES[f.__qualname__] = SourceInfo(SPEC, f, inspect.getmodule(f)) # the typing is right, but makefun.wraps does not preserve signatures - return make_resource(resource_name, f) # type: ignore - + return make_resource(resource_name, f) # if data is callable or none use decorator if data is None: diff --git a/dlt/extract/exceptions.py b/dlt/extract/exceptions.py index bf0d4bff18..a3fc162113 100644 --- a/dlt/extract/exceptions.py +++ b/dlt/extract/exceptions.py @@ -44,6 +44,17 @@ def __init__(self) -> None: Please note that for resources created from functions or generators, the name is the function name by default.""") +class DependentResourceIsNotCallable(DltResourceException): + def __init__(self, resource_name: str) -> None: + super().__init__(resource_name, f"Attempted to call the dependent resource {resource_name}. Do not call the dependent resources. They will be called only when iterated.") + + +class ResourceNotFoundError(DltResourceException, KeyError): + def __init__(self, resource_name: str, context: str) -> None: + self.resource_name = resource_name + super().__init__(resource_name, f"Resource with a name {resource_name} could not be found. {context}") + + class InvalidResourceDataType(DltResourceException): def __init__(self, resource_name: str, item: Any, _typ: Type[Any], msg: str) -> None: self.item = item @@ -51,20 +62,40 @@ def __init__(self, resource_name: str, item: Any, _typ: Type[Any], msg: str) -> super().__init__(resource_name, f"Cannot create resource {resource_name} from specified data. " + msg) -class InvalidResourceAsyncDataType(InvalidResourceDataType): +class InvalidResourceDataTypeAsync(InvalidResourceDataType): def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None: super().__init__(resource_name, item, _typ, "Async iterators and generators are not valid resources. 
Please use standard iterators and generators that yield Awaitables instead (for example by yielding from an async function without await).")


-class InvalidResourceBasicDataType(InvalidResourceDataType):
+class InvalidResourceDataTypeBasic(InvalidResourceDataType):
     def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None:
         super().__init__(resource_name, item, _typ, f"Resources cannot be strings or dictionaries but {_typ.__name__} was provided. Please pass your data in a list or as a function yielding items. If you want to process just one data item, enclose it in a list.")


-class GeneratorFunctionNotAllowedAsParentResource(DltResourceException):
+class InvalidResourceDataTypeFunctionNotAGenerator(InvalidResourceDataType):
+    def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None:
+        super().__init__(resource_name, item, _typ, "Please make sure that the function decorated with @resource uses 'yield' to return the data.")
+
+
+class InvalidResourceDataTypeMultiplePipes(InvalidResourceDataType):
+    def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None:
+        super().__init__(resource_name, item, _typ, "Resources with multiple parallel data pipes are not yet supported. This problem most often happens when you are creating a source with the @source decorator that has several resources with the same name.")
+
+
+class InvalidDependentResourceDataTypeGeneratorFunctionRequired(InvalidResourceDataType):
+    def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None:
+        super().__init__(resource_name, item, _typ, "A dependent resource must be a decorated function that takes a data item as its only argument.")
+
+
+class InvalidParentResourceDataType(InvalidResourceDataType):
+    def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None:
+        super().__init__(resource_name, item, _typ, f"A parent resource of {resource_name} is of type {_typ.__name__}. Did you forget to use the `@resource` decorator or the `resource` function?")
+
+
+class InvalidParentResourceIsAFunction(DltResourceException):
     def __init__(self, resource_name: str, func_name: str) -> None:
         self.func_name = func_name
-        super().__init__(resource_name, f"A parent resource {resource_name} of dependent resource {resource_name} is a function but must be a resource. Please decorate function")
+        super().__init__(resource_name, f"A parent resource {func_name} of dependent resource {resource_name} is a function.
Please decorate it with '@resource' or pass to 'resource' function.") class TableNameMissing(DltSourceException): diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index e464a39031..418f9627c8 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -56,7 +56,6 @@ def _write_item(table_name: str, item: TDataItems) -> None: # normalize table name before writing so the name match the name in schema # note: normalize function should be cached so there's almost no penalty on frequent calling # note: column schema is not required for jsonl writer used here - # TODO: consider dropping DLT_METADATA_FIELD in all items before writing, this however takes CPU time # event.pop(DLT_METADATA_FIELD, None) # type: ignore storage.write_data_item(extract_id, schema.name, schema.normalize_table_name(table_name), item, None) @@ -78,9 +77,11 @@ def _write_dynamic_table(resource: DltResource, item: TDataItem) -> None: _write_item(table_name, item) # yield from all selected pipes - for pipe_item in PipeIterator.from_pipes(source.pipes, max_parallel_items=max_parallel_items, workers=workers, futures_poll_interval=futures_poll_interval): + for pipe_item in PipeIterator.from_pipes(source.resources.selected_pipes, max_parallel_items=max_parallel_items, workers=workers, futures_poll_interval=futures_poll_interval): # get partial table from table template - resource = source.resource_by_pipe(pipe_item.pipe) + # TODO: many resources may be returned. if that happens the item meta must be present with table name and this name must match one of resources + # TDataItemMeta(table_name, requires_resource, write_disposition, columns, parent etc.) + resource = source.resources.find_by_pipe(pipe_item.pipe) if resource._table_name_hint_fun: if isinstance(pipe_item.item, List): for item in pipe_item.item: diff --git a/dlt/extract/pipe.py b/dlt/extract/pipe.py index ae56b36a7f..5b71c5191d 100644 --- a/dlt/extract/pipe.py +++ b/dlt/extract/pipe.py @@ -191,7 +191,7 @@ def evaluate_head(self) -> None: # if pipe head is callable then call it if self.parent is None: if callable(self.head): - self._steps[0] = self.head() + self._steps[0] = self.head() # type: ignore def __repr__(self) -> str: return f"Pipe {self.name} ({self._pipe_id}) at {id(self)}" diff --git a/dlt/extract/source.py b/dlt/extract/source.py index fe99c3a1cb..85499fda57 100644 --- a/dlt/extract/source.py +++ b/dlt/extract/source.py @@ -2,18 +2,22 @@ from copy import deepcopy import inspect from collections.abc import Mapping as C_Mapping -from typing import AsyncIterable, AsyncIterator, Iterable, Iterator, List, Set, Sequence, Union, cast, Any +from typing import AsyncIterable, AsyncIterator, Dict, Iterable, Iterator, List, Set, Sequence, Union, cast, Any +from typing_extensions import Self from dlt.common.schema import Schema from dlt.common.schema.utils import new_table from dlt.common.schema.typing import TPartialTableSchema, TTableSchemaColumns, TWriteDisposition -from dlt.common.typing import TDataItem, TDataItems +from dlt.common.typing import AnyFun, TDataItem, TDataItems from dlt.common.configuration.container import Container from dlt.common.pipeline import PipelineContext from dlt.extract.typing import TFunHintTemplate, TTableHintTemplate, TTableSchemaTemplate from dlt.extract.pipe import FilterItem, Pipe, PipeIterator -from dlt.extract.exceptions import CreatePipeException, DataItemRequiredForDynamicTableHints, GeneratorFunctionNotAllowedAsParentResource, InconsistentTableTemplate, InvalidResourceAsyncDataType, 
InvalidResourceBasicDataType, ResourceNameMissing, TableNameMissing +from dlt.extract.exceptions import ( + DependentResourceIsNotCallable, InvalidDependentResourceDataTypeGeneratorFunctionRequired, InvalidParentResourceDataType, InvalidParentResourceIsAFunction, InvalidResourceDataType, InvalidResourceDataTypeFunctionNotAGenerator, + ResourceNotFoundError, CreatePipeException, DataItemRequiredForDynamicTableHints, InconsistentTableTemplate, InvalidResourceDataTypeAsync, InvalidResourceDataTypeBasic, + InvalidResourceDataTypeMultiplePipes, ResourceNameMissing, TableNameMissing) class DltResourceSchema: @@ -108,6 +112,7 @@ def new_table_template( class DltResource(Iterable[TDataItems], DltResourceSchema): def __init__(self, pipe: Pipe, table_schema_template: TTableSchemaTemplate, selected: bool): + # TODO: allow resource to take name independent from pipe name self.name = pipe.name self.selected = selected self._pipe = pipe @@ -126,7 +131,7 @@ def from_data(cls, data: Any, name: str = None, table_schema_template: TTableSch name = name or data.__name__ # function must be a generator if not inspect.isgeneratorfunction(inspect.unwrap(data)): - raise ResourceFunctionNotAGenerator(name) + raise InvalidResourceDataTypeFunctionNotAGenerator(name, data, type(data)) # if generator, take name from it if inspect.isgenerator(data): @@ -138,20 +143,20 @@ def from_data(cls, data: Any, name: str = None, table_schema_template: TTableSch # several iterable types are not allowed and must be excluded right away if isinstance(data, (AsyncIterator, AsyncIterable)): - raise InvalidResourceAsyncDataType(name, data, type(data)) + raise InvalidResourceDataTypeAsync(name, data, type(data)) if isinstance(data, (str, dict)): - raise InvalidResourceBasicDataType(name, data, type(data)) + raise InvalidResourceDataTypeBasic(name, data, type(data)) # check if depends_on is a valid resource parent_pipe: Pipe = None if depends_on: + # must be a callable with single argument if not callable(data): - raise DependentResourceMustBeAGeneratorFunction() + raise InvalidDependentResourceDataTypeGeneratorFunctionRequired(name, data, type(data)) else: - pass - # TODO: check sig if takes just one argument - # if sig_valid(): - # raise DependentResourceMustTakeDataItemArgument() + if cls.is_valid_dependent_generator_function(data): + raise InvalidDependentResourceDataTypeGeneratorFunctionRequired(name, data, type(data)) + # parent resource if isinstance(depends_on, Pipe): parent_pipe = depends_on elif isinstance(depends_on, DltResource): @@ -159,10 +164,9 @@ def from_data(cls, data: Any, name: str = None, table_schema_template: TTableSch else: # if this is generator function provide nicer exception if callable(depends_on): - raise GeneratorFunctionNotAllowedAsParentResource(depends_on.__name__) + raise InvalidParentResourceIsAFunction(name, depends_on.__name__) else: - raise ParentNotAResource() - + raise InvalidParentResourceDataType(name, depends_on, type(depends_on)) # create resource from iterator, iterable or generator function if isinstance(data, (Iterable, Iterator)): @@ -173,7 +177,13 @@ def from_data(cls, data: Any, name: str = None, table_schema_template: TTableSch return cls(pipe, table_schema_template, selected) else: # some other data type that is not supported - raise InvalidResourceDataType("Invalid data type for DltResource", type(data)) + raise InvalidResourceDataType(name, data, type(data), f"The data type is {type(data).__name__}") + + + def add_pipe(self, data: Any) -> None: + """Creates additional pipe for the 
resource from the specified data""" + # TODO: (1) self resource cannot be a dependent one (2) if data is resource both self must and it must be selected/unselected + cannot be dependent + raise InvalidResourceDataTypeMultiplePipes(self.name, data, type(data)) def select(self, *table_names: Iterable[str]) -> "DltResource": @@ -187,18 +197,21 @@ def _filter(item: TDataItem) -> bool: self._pipe.add_step(FilterItem(_filter)) return self - def map(self) -> None: + def map(self) -> None: # noqa: A003 raise NotImplementedError() def flat_map(self) -> None: raise NotImplementedError() - def filter(self) -> None: + def filter(self) -> None: # noqa: A003 raise NotImplementedError() def __call__(self, *args: Any, **kwargs: Any) -> Any: # make resource callable to support parametrized resources which are functions taking arguments - _data = self._pipe.head(*args, **kwargs) + if self._pipe.parent: + raise DependentResourceIsNotCallable(self.name) + # pass the call parameters to the pipe's head + _data = self._pipe.head(*args, **kwargs) # type: ignore # create new resource from extracted data return DltResource.from_data(_data, self.name, self._table_schema_template, self.selected, self._pipe.parent) @@ -208,13 +221,53 @@ def __iter__(self) -> Iterator[TDataItems]: def __repr__(self) -> str: return f"DltResource {self.name} ({self._pipe._pipe_id}) at {id(self)}" + @staticmethod + def is_valid_dependent_generator_function(f: AnyFun) -> bool: + sig = inspect.signature(f) + return len(sig.parameters) == 0 + + +class DltResourceDict(Dict[str, DltResource]): + @property + def selected(self) -> Dict[str, DltResource]: + return {k:v for k,v in self.items() if v.selected} + + @property + def pipes(self) -> List[Pipe]: + # TODO: many resources may share the same pipe so return ordered set + return [r._pipe for r in self.values()] + + @property + def selected_pipes(self) -> Sequence[Pipe]: + # TODO: many resources may share the same pipe so return ordered set + return [r._pipe for r in self.values() if r.selected] + + def select(self, *resource_names: str) -> Dict[str, DltResource]: + # checks if keys are present + for name in resource_names: + try: + self.__getitem__(name) + except KeyError: + raise ResourceNotFoundError(name, "Requested resource could not be selected because it is not present in the source.") + # set the selected flags + for resource in self.values(): + self[resource.name].selected = resource.name in resource_names + return self.selected + + def find_by_pipe(self, pipe: Pipe) -> DltResource: + # TODO: many resources may share the same pipe so return a list and also filter the resources by self._enabled_resource_names + # identify pipes by memory pointer + return next(r for r in self.values() if r._pipe._pipe_id is pipe._pipe_id) + class DltSource(Iterable[TDataItems]): def __init__(self, schema: Schema, resources: Sequence[DltResource] = None) -> None: self.name = schema.name self._schema = schema - self._resources: List[DltResource] = list(resources or []) - self._enabled_resource_names: Set[str] = set(r.name for r in self._resources if r.selected) + self._resources: DltResourceDict = DltResourceDict() + if resources: + for resource in resources: + self._add_resource(resource) @classmethod def from_data(cls, schema: Schema, data: Any) -> "DltSource": @@ -222,10 +275,6 @@ def from_data(cls, schema: Schema, data: Any) -> "DltSource": if isinstance(data, DltSource): return data - # several iterable types are not allowed and must be excluded right away - if isinstance(data, (AsyncIterator, 
AsyncIterable, str, dict)): - raise InvalidSourceDataType("Invalid data type for DltSource", type(data)) - # in case of sequence, enumerate items and convert them into resources if isinstance(data, Sequence): resources = [DltResource.from_data(i) for i in data] @@ -235,22 +284,13 @@ def from_data(cls, schema: Schema, data: Any) -> "DltSource": return cls(schema, resources) - def __getitem__(self, name: str) -> List[DltResource]: - if name not in self._enabled_resource_names: - raise KeyError(name) - return [r for r in self._resources if r.name == name] - - def resource_by_pipe(self, pipe: Pipe) -> DltResource: - # identify pipes by memory pointer - return next(r for r in self._resources if r._pipe._pipe_id is pipe._pipe_id) - @property - def resources(self) -> Sequence[DltResource]: - return [r for r in self._resources if r.name in self._enabled_resource_names] + def resources(self) -> DltResourceDict: + return self._resources @property - def pipes(self) -> Sequence[Pipe]: - return [r._pipe for r in self._resources if r.name in self._enabled_resource_names] + def selected_resources(self) -> Dict[str, DltResource]: + return self._resources.selected @property def schema(self) -> Schema: @@ -262,26 +302,30 @@ def schema(self, value: Schema) -> None: def discover_schema(self) -> Schema: # extract tables from all resources and update internal schema - for r in self._resources: + for r in self._resources.values(): # names must be normalized here with contextlib.suppress(DataItemRequiredForDynamicTableHints): partial_table = self._schema.normalize_table_identifiers(r.table_schema()) self._schema.update_schema(partial_table) return self._schema - def select(self, *resource_names: str) -> "DltSource": - # make sure all selected resources exist - for name in resource_names: - self.__getitem__(name) - self._enabled_resource_names = set(resource_names) + def with_resources(self, *resource_names: str) -> "DltSource": + self._resources.select(*resource_names) return self def run(self, destination: Any) -> Any: return Container()[PipelineContext].pipeline().run(source=self, destination=destination) + def _add_resource(self, resource: DltResource) -> None: + if resource.name in self._resources: + # for resources with the same name try to add the resource as an another pipe + self._resources[resource.name].add_pipe(resource) + else: + self._resources[resource.name] = resource + def __iter__(self) -> Iterator[TDataItems]: - return map(lambda item: item.item, PipeIterator.from_pipes(self.pipes)) + return map(lambda item: item.item, PipeIterator.from_pipes(self._resources.selected_pipes)) def __repr__(self) -> str: return f"DltSource {self.name} at {id(self)}" diff --git a/dlt/extract/typing.py b/dlt/extract/typing.py index 7c79453eb3..b7b33b6c65 100644 --- a/dlt/extract/typing.py +++ b/dlt/extract/typing.py @@ -15,7 +15,7 @@ class TTableSchemaTemplate(TypedDict, total=False): name: TTableHintTemplate[str] - description: TTableHintTemplate[str] + # description: TTableHintTemplate[str] write_disposition: TTableHintTemplate[TWriteDisposition] # table_sealed: Optional[bool] parent: TTableHintTemplate[str] From 161b8d6dc879931fa4ec8c1114a848a7e95a74d8 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Mon, 31 Oct 2022 20:26:18 +0100 Subject: [PATCH 64/66] changes the schema file pattern _schema. -> .schema. 
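With this change a schema named `event` is stored as `event.schema.json` (or `.yaml`) instead of `event_schema.json`. A minimal illustration of the new naming, using plain string formatting that matches the shape of `NAMED_SCHEMA_FILE_PATTERN` below (the variable names here are only for the example):

    # after this patch the pattern is "%s.schema.%s" (schema name, file extension)
    schema_name, extension = "event", "json"
    assert "%s.schema.%s" % (schema_name, extension) == "event.schema.json"
    # before this patch the same schema was saved as "event_schema.json"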
--- dlt/common/storages/normalize_storage.py | 3 --- dlt/common/storages/schema_storage.py | 4 ++-- .../{discord_schema.yml => discord.schema.yml} | 0 .../{hubspot_schema.yml => hubspot.schema.yml} | 0 ...red_demo_schema.yml => inferred_demo.schema.yml} | 0 .../schemas/{rasa_schema.yml => rasa.schema.yml} | 0 .../ev1/{event_schema.7z => event.schema.7z} | Bin .../ev1/{event_schema.json => event.schema.json} | 0 .../ev1/{model_schema.json => model.schema.json} | 0 .../ev2/{event_schema.json => event.schema.json} | 0 .../rasa/{event_schema.json => event.schema.json} | 0 .../rasa/{model_schema.json => model.schema.json} | 0 tests/common/configuration/test_container.py | 6 +++--- tests/common/configuration/test_environ_provider.py | 6 +++--- tests/common/configuration/utils.py | 2 +- tests/common/runners/test_runnable.py | 10 ++++------ tests/common/schema/test_coercion.py | 6 +++--- tests/common/schema/test_inference.py | 2 +- tests/common/schema/test_schema.py | 12 ++++++------ tests/common/schema/test_versioning.py | 2 +- tests/common/storages/test_schema_storage.py | 4 +--- tests/common/test_validation.py | 2 +- .../cases/{event_schema.json => event.schema.json} | 0 .../{ethereum_schema.json => ethereum.schema.json} | 0 .../{event_schema.json => event.schema.json} | 0 25 files changed, 26 insertions(+), 33 deletions(-) rename examples/schemas/{discord_schema.yml => discord.schema.yml} (100%) rename examples/schemas/{hubspot_schema.yml => hubspot.schema.yml} (100%) rename examples/schemas/{inferred_demo_schema.yml => inferred_demo.schema.yml} (100%) rename examples/schemas/{rasa_schema.yml => rasa.schema.yml} (100%) rename tests/common/cases/schemas/ev1/{event_schema.7z => event.schema.7z} (100%) rename tests/common/cases/schemas/ev1/{event_schema.json => event.schema.json} (100%) rename tests/common/cases/schemas/ev1/{model_schema.json => model.schema.json} (100%) rename tests/common/cases/schemas/ev2/{event_schema.json => event.schema.json} (100%) rename tests/common/cases/schemas/rasa/{event_schema.json => event.schema.json} (100%) rename tests/common/cases/schemas/rasa/{model_schema.json => model.schema.json} (100%) rename tests/load/cases/{event_schema.json => event.schema.json} (100%) rename tests/normalize/cases/schemas/{ethereum_schema.json => ethereum.schema.json} (100%) rename tests/normalize/cases/schemas/{event_schema.json => event.schema.json} (100%) diff --git a/dlt/common/storages/normalize_storage.py b/dlt/common/storages/normalize_storage.py index e20652e7b5..b05c3adf2a 100644 --- a/dlt/common/storages/normalize_storage.py +++ b/dlt/common/storages/normalize_storage.py @@ -32,13 +32,10 @@ def __init__(self, is_owner: bool, config: NormalizeVolumeConfiguration = Config def __init__(self, is_owner: bool, config: NormalizeVolumeConfiguration = ConfigValue) -> None: super().__init__(NormalizeStorage.STORAGE_VERSION, is_owner, FileStorage(config.normalize_volume_path, "t", makedirs=is_owner)) self.config = config - print(is_owner) if is_owner: self.initialize_storage() def initialize_storage(self) -> None: - print(self.storage.storage_path) - print(NormalizeStorage.EXTRACTED_FOLDER) self.storage.create_folder(NormalizeStorage.EXTRACTED_FOLDER, exists_ok=True) def list_files_to_normalize_sorted(self) -> Sequence[str]: diff --git a/dlt/common/storages/schema_storage.py b/dlt/common/storages/schema_storage.py index 309469695b..0127fc2b6f 100644 --- a/dlt/common/storages/schema_storage.py +++ b/dlt/common/storages/schema_storage.py @@ -16,7 +16,7 @@ class SchemaStorage(Mapping[str, 
Schema]): SCHEMA_FILE_NAME = "schema.%s" - NAMED_SCHEMA_FILE_PATTERN = f"%s_{SCHEMA_FILE_NAME}" + NAMED_SCHEMA_FILE_PATTERN = f"%s.{SCHEMA_FILE_NAME}" @overload def __init__(self, config: SchemaVolumeConfiguration, makedirs: bool = False) -> None: @@ -77,7 +77,7 @@ def has_schema(self, name: str) -> bool: def list_schemas(self) -> List[str]: files = self.storage.list_folder_files(".", to_root=False) # extract names - return [re.split("_|schema", f)[0] for f in files] + return [f.split(".")[0] for f in files] def __getitem__(self, name: str) -> Schema: return self.load_schema(name) diff --git a/examples/schemas/discord_schema.yml b/examples/schemas/discord.schema.yml similarity index 100% rename from examples/schemas/discord_schema.yml rename to examples/schemas/discord.schema.yml diff --git a/examples/schemas/hubspot_schema.yml b/examples/schemas/hubspot.schema.yml similarity index 100% rename from examples/schemas/hubspot_schema.yml rename to examples/schemas/hubspot.schema.yml diff --git a/examples/schemas/inferred_demo_schema.yml b/examples/schemas/inferred_demo.schema.yml similarity index 100% rename from examples/schemas/inferred_demo_schema.yml rename to examples/schemas/inferred_demo.schema.yml diff --git a/examples/schemas/rasa_schema.yml b/examples/schemas/rasa.schema.yml similarity index 100% rename from examples/schemas/rasa_schema.yml rename to examples/schemas/rasa.schema.yml diff --git a/tests/common/cases/schemas/ev1/event_schema.7z b/tests/common/cases/schemas/ev1/event.schema.7z similarity index 100% rename from tests/common/cases/schemas/ev1/event_schema.7z rename to tests/common/cases/schemas/ev1/event.schema.7z diff --git a/tests/common/cases/schemas/ev1/event_schema.json b/tests/common/cases/schemas/ev1/event.schema.json similarity index 100% rename from tests/common/cases/schemas/ev1/event_schema.json rename to tests/common/cases/schemas/ev1/event.schema.json diff --git a/tests/common/cases/schemas/ev1/model_schema.json b/tests/common/cases/schemas/ev1/model.schema.json similarity index 100% rename from tests/common/cases/schemas/ev1/model_schema.json rename to tests/common/cases/schemas/ev1/model.schema.json diff --git a/tests/common/cases/schemas/ev2/event_schema.json b/tests/common/cases/schemas/ev2/event.schema.json similarity index 100% rename from tests/common/cases/schemas/ev2/event_schema.json rename to tests/common/cases/schemas/ev2/event.schema.json diff --git a/tests/common/cases/schemas/rasa/event_schema.json b/tests/common/cases/schemas/rasa/event.schema.json similarity index 100% rename from tests/common/cases/schemas/rasa/event_schema.json rename to tests/common/cases/schemas/rasa/event.schema.json diff --git a/tests/common/cases/schemas/rasa/model_schema.json b/tests/common/cases/schemas/rasa/model.schema.json similarity index 100% rename from tests/common/cases/schemas/rasa/model_schema.json rename to tests/common/cases/schemas/rasa/model.schema.json diff --git a/tests/common/configuration/test_container.py b/tests/common/configuration/test_container.py index 2581174e93..45f0e29738 100644 --- a/tests/common/configuration/test_container.py +++ b/tests/common/configuration/test_container.py @@ -57,7 +57,7 @@ def test_get_default_injectable_config(container: Container) -> None: def test_raise_on_no_default_value(container: Container) -> None: - with pytest.raises(ContextDefaultCannotBeCreated) as py_ex: + with pytest.raises(ContextDefaultCannotBeCreated): container[NoDefaultInjectableContext] # ok when injected @@ -129,8 +129,8 @@ def 
test_container_provider(container: Container) -> None: provider.get_value("n/a", InjectableTestContext, ("ns1",)) # type hints that are not classes - l = Literal["a"] - v, k = provider.get_value("n/a", l) + literal = Literal["a"] + v, k = provider.get_value("n/a", literal) assert v is None assert k == "typing.Literal['a']" diff --git a/tests/common/configuration/test_environ_provider.py b/tests/common/configuration/test_environ_provider.py index 9f7a2c2885..87cb1de30e 100644 --- a/tests/common/configuration/test_environ_provider.py +++ b/tests/common/configuration/test_environ_provider.py @@ -98,9 +98,9 @@ def test_configuration_files(environment: Any) -> None: C = resolve.resolve_configuration(MockProdConfigurationVar()) assert C.config_files_storage_path == environment["CONFIG_FILES_STORAGE_PATH"] assert C.has_configuration_file("hasn't") is False - assert C.has_configuration_file("event_schema.json") is True - assert C.get_configuration_file_path("event_schema.json") == "./tests/common/cases/schemas/ev1/event_schema.json" - with C.open_configuration_file("event_schema.json", "r") as f: + assert C.has_configuration_file("event.schema.json") is True + assert C.get_configuration_file_path("event.schema.json") == "./tests/common/cases/schemas/ev1/event.schema.json" + with C.open_configuration_file("event.schema.json", "r") as f: f.read() with pytest.raises(ConfigFileNotFoundException): C.open_configuration_file("hasn't", "r") diff --git a/tests/common/configuration/utils.py b/tests/common/configuration/utils.py index 9ba58bb5e8..e03a123c47 100644 --- a/tests/common/configuration/utils.py +++ b/tests/common/configuration/utils.py @@ -93,7 +93,7 @@ def reset_stats(self) -> None: def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Optional[Any], str]: self.last_namespace = namespaces self.last_namespaces.append(namespaces) - print("|".join(namespaces) + "-" + key) + # print("|".join(namespaces) + "-" + key) if namespaces == self.return_value_on: rv = self.value else: diff --git a/tests/common/runners/test_runnable.py b/tests/common/runners/test_runnable.py index 4725fd6fcf..c5eb276dda 100644 --- a/tests/common/runners/test_runnable.py +++ b/tests/common/runners/test_runnable.py @@ -3,8 +3,7 @@ from multiprocessing.pool import Pool from multiprocessing.dummy import Pool as ThreadPool -from dlt.common.utils import uniq_id -from dlt.normalize.configuration import NormalizeConfiguration +from dlt.normalize.configuration import SchemaVolumeConfiguration from tests.common.runners.utils import _TestRunnable from tests.utils import skipifspawn @@ -68,9 +67,8 @@ def test_weak_pool_ref() -> None: def test_configuredworker() -> None: - # call worker method with CONFIG values that should be restored into CONFIG type - config = NormalizeConfiguration() + config = SchemaVolumeConfiguration() config["import_schema_path"] = "test_schema_path" _worker_1(config, "PX1", par2="PX2") @@ -79,9 +77,9 @@ def test_configuredworker() -> None: p.starmap(_worker_1, [(config, "PX1", "PX2")]) -def _worker_1(CONFIG: NormalizeConfiguration, par1: str, par2: str = "DEFAULT") -> None: +def _worker_1(CONFIG: SchemaVolumeConfiguration, par1: str, par2: str = "DEFAULT") -> None: # a correct type was passed - assert type(CONFIG) is NormalizeConfiguration + assert type(CONFIG) is SchemaVolumeConfiguration # check if config values are restored assert CONFIG.import_schema_path == "test_schema_path" # check if other parameters are correctly diff --git a/tests/common/schema/test_coercion.py 
b/tests/common/schema/test_coercion.py index 660ee22ce2..2c82a95429 100644 --- a/tests/common/schema/test_coercion.py +++ b/tests/common/schema/test_coercion.py @@ -156,13 +156,13 @@ def test_coerce_type_to_timestamp() -> None: # test wrong unix timestamps with pytest.raises(ValueError): - print(utils.coerce_type("timestamp", "double", -1000000000000000000000000000)) + utils.coerce_type("timestamp", "double", -1000000000000000000000000000) with pytest.raises(ValueError): - print(utils.coerce_type("timestamp", "double", 1000000000000000000000000000)) + utils.coerce_type("timestamp", "double", 1000000000000000000000000000) # formats with timezones are not parsed with pytest.raises(ValueError): - print(utils.coerce_type("timestamp", "text", "06/04/22, 11:15PM IST")) + utils.coerce_type("timestamp", "text", "06/04/22, 11:15PM IST") # we do not parse RFC 822, 2822, 850 etc. with pytest.raises(ValueError): diff --git a/tests/common/schema/test_inference.py b/tests/common/schema/test_inference.py index 635e6a47d2..73a9021def 100644 --- a/tests/common/schema/test_inference.py +++ b/tests/common/schema/test_inference.py @@ -125,7 +125,7 @@ def test_coerce_row(schema: Schema) -> None: schema.update_schema(new_table) with pytest.raises(CannotCoerceColumnException) as exc_val: # now pass the binary that would create binary variant - but the column is occupied by text type - print(schema.coerce_row("event_user", None, {"new_colbool": pendulum.now()})) + schema.coerce_row("event_user", None, {"new_colbool": pendulum.now()}) assert exc_val.value.table_name == "event_user" assert exc_val.value.column_name == "new_colbool__v_timestamp" assert exc_val.value.from_type == "timestamp" diff --git a/tests/common/schema/test_schema.py b/tests/common/schema/test_schema.py index 37ab37630e..b6fd7abb63 100644 --- a/tests/common/schema/test_schema.py +++ b/tests/common/schema/test_schema.py @@ -16,7 +16,7 @@ from tests.common.utils import load_json_case, load_yml_case SCHEMA_NAME = "event" -EXPECTED_FILE_NAME = f"{SCHEMA_NAME}_schema.json" +EXPECTED_FILE_NAME = f"{SCHEMA_NAME}.schema.json" @pytest.fixture @@ -158,7 +158,7 @@ def test_save_store_schema_custom_normalizers(cn_schema: Schema, schema_storage: def test_upgrade_engine_v1_schema() -> None: - schema_dict: DictStrAny = load_json_case("schemas/ev1/event_schema") + schema_dict: DictStrAny = load_json_case("schemas/ev1/event.schema") # ensure engine v1 assert schema_dict["engine_version"] == 1 # schema_dict will be updated to new engine version @@ -168,14 +168,14 @@ def test_upgrade_engine_v1_schema() -> None: assert len(schema_dict["tables"]) == 27 # upgrade schema eng 2 -> 4 - schema_dict: DictStrAny = load_json_case("schemas/ev2/event_schema") + schema_dict: DictStrAny = load_json_case("schemas/ev2/event.schema") assert schema_dict["engine_version"] == 2 upgraded = utils.upgrade_engine_version(schema_dict, from_engine=2, to_engine=4) assert upgraded["engine_version"] == 4 utils.validate_stored_schema(upgraded) # upgrade 1 -> 4 - schema_dict: DictStrAny = load_json_case("schemas/ev1/event_schema") + schema_dict: DictStrAny = load_json_case("schemas/ev1/event.schema") assert schema_dict["engine_version"] == 1 upgraded = utils.upgrade_engine_version(schema_dict, from_engine=1, to_engine=4) assert upgraded["engine_version"] == 4 @@ -183,7 +183,7 @@ def test_upgrade_engine_v1_schema() -> None: def test_unknown_engine_upgrade() -> None: - schema_dict: TStoredSchema = load_json_case("schemas/ev1/event_schema") + schema_dict: TStoredSchema = 
load_json_case("schemas/ev1/event.schema") # there's no path to migrate 3 -> 2 schema_dict["engine_version"] = 3 with pytest.raises(SchemaEngineNoUpgradePathException): @@ -242,7 +242,7 @@ def test_rasa_event_hints(columns: Sequence[str], hint: str, value: bool, schema def test_filter_hints_table() -> None: # this schema contains event_bot table with expected hints - schema_dict: TStoredSchema = load_json_case("schemas/ev1/event_schema") + schema_dict: TStoredSchema = load_json_case("schemas/ev1/event.schema") schema = Schema.from_dict(schema_dict) # get all not_null columns on event bot_case: StrAny = load_json_case("mod_bot_case") diff --git a/tests/common/schema/test_versioning.py b/tests/common/schema/test_versioning.py index b8be8be019..79d23eb417 100644 --- a/tests/common/schema/test_versioning.py +++ b/tests/common/schema/test_versioning.py @@ -115,7 +115,7 @@ def test_version_preserve_on_reload(remove_defaults: bool) -> None: assert saved_schema.stored_version_hash == schema.stored_version_hash # serialize as yaml, for that use a schema that was stored in json - rasa_v4: TStoredSchema = load_json_case("schemas/rasa/event_schema") + rasa_v4: TStoredSchema = load_json_case("schemas/rasa/event.schema") rasa_schema = Schema.from_dict(rasa_v4) rasa_yml = rasa_schema.to_pretty_yaml(remove_defaults=remove_defaults) saved_rasa_schema = Schema.from_dict(yaml.safe_load(rasa_yml)) diff --git a/tests/common/storages/test_schema_storage.py b/tests/common/storages/test_schema_storage.py index 2beed89c92..341fdfd2e5 100644 --- a/tests/common/storages/test_schema_storage.py +++ b/tests/common/storages/test_schema_storage.py @@ -4,11 +4,9 @@ import yaml from dlt.common import json -from dlt.common.typing import DictStrAny from dlt.common.schema.schema import Schema from dlt.common.schema.typing import TStoredSchema from dlt.common.schema.utils import default_normalizers -from dlt.common.configuration import resolve_configuration from dlt.common.configuration.specs import SchemaVolumeConfiguration from dlt.common.storages.exceptions import InStorageSchemaModified, SchemaNotFoundError from dlt.common.storages import SchemaStorage, LiveSchemaStorage, FileStorage @@ -245,7 +243,7 @@ def test_save_store_schema(storage: SchemaStorage) -> None: def prepare_import_folder(storage: SchemaStorage) -> None: - shutil.copy(yml_case_path("schemas/eth/ethereum_schema_v4"), storage.storage.make_full_path("../import/ethereum_schema.yaml")) + shutil.copy(yml_case_path("schemas/eth/ethereum_schema_v4"), storage.storage.make_full_path("../import/ethereum.schema.yaml")) def assert_schema_imported(synced_storage: SchemaStorage, storage: SchemaStorage) -> Schema: diff --git a/tests/common/test_validation.py b/tests/common/test_validation.py index 019752c200..4bc516615b 100644 --- a/tests/common/test_validation.py +++ b/tests/common/test_validation.py @@ -88,7 +88,7 @@ def test_validate_schema_cases() -> None: validate_dict(TStoredSchema, schema_dict, ".", lambda k: not k.startswith("x-"), simple_regex_validator) - # with open("tests/common/cases/schemas/rasa/event_schema.json") as f: + # with open("tests/common/cases/schemas/rasa/event.schema.json") as f: # schema_dict: TStoredSchema = json.load(f) # validate_dict(TStoredSchema, schema_dict, ".", lambda k: not k.startswith("x-")) diff --git a/tests/load/cases/event_schema.json b/tests/load/cases/event.schema.json similarity index 100% rename from tests/load/cases/event_schema.json rename to tests/load/cases/event.schema.json diff --git 
a/tests/normalize/cases/schemas/ethereum_schema.json b/tests/normalize/cases/schemas/ethereum.schema.json similarity index 100% rename from tests/normalize/cases/schemas/ethereum_schema.json rename to tests/normalize/cases/schemas/ethereum.schema.json diff --git a/tests/normalize/cases/schemas/event_schema.json b/tests/normalize/cases/schemas/event.schema.json similarity index 100% rename from tests/normalize/cases/schemas/event_schema.json rename to tests/normalize/cases/schemas/event.schema.json From 32d769b957e9d8c0049eb3aecae3d9d1db2310b5 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Mon, 31 Oct 2022 20:29:16 +0100 Subject: [PATCH 65/66] adds flag to wipe out loader storage before initializing --- dlt/common/destination.py | 6 ++++-- dlt/load/bigquery/bigquery.py | 4 +++- dlt/load/bigquery/configuration.py | 1 - dlt/load/dummy/__init__.py | 1 - dlt/load/dummy/dummy.py | 2 +- dlt/load/load.py | 2 +- dlt/load/redshift/redshift.py | 4 +++- 7 files changed, 12 insertions(+), 8 deletions(-) diff --git a/dlt/common/destination.py b/dlt/common/destination.py index 328a332450..5489ebe5a4 100644 --- a/dlt/common/destination.py +++ b/dlt/common/destination.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from importlib import import_module from nis import cat -from types import TracebackType +from types import ModuleType, TracebackType from typing import ClassVar, List, Optional, Literal, Type, Protocol, Union, TYPE_CHECKING, cast from dlt.common.schema import Schema @@ -99,7 +99,7 @@ def __init__(self, schema: Schema, config: DestinationClientConfiguration) -> No self.config = config @abstractmethod - def initialize_storage(self) -> None: + def initialize_storage(self, wipe_data: bool = False) -> None: pass @abstractmethod @@ -133,6 +133,8 @@ def capabilities(cls) -> DestinationCapabilitiesContext: class DestinationReference(Protocol): + __name__: str + def capabilities(self) -> DestinationCapabilitiesContext: ... 
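The abstract `initialize_storage` above now takes a `wipe_data` flag; the concrete clients in the hunks that follow only guard it with `NotImplementedError` for now. Below is a minimal sketch of how a SQL-based client could honor the flag once implemented. The `drop_dataset` helper is an assumption used only for illustration; `has_dataset`/`create_dataset` mirror the clients touched by this patch.

    from typing import Any

    class SketchSqlJobClient:
        # illustrative sketch only - not part of this patch
        def __init__(self, sql_client: Any) -> None:
            self.sql_client = sql_client

        def initialize_storage(self, wipe_data: bool = False) -> None:
            if wipe_data and self.sql_client.has_dataset():
                # remove all tables and the dataset itself before re-creating it
                self.sql_client.drop_dataset()
            if not self.sql_client.has_dataset():
                self.sql_client.create_dataset()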
diff --git a/dlt/load/bigquery/bigquery.py b/dlt/load/bigquery/bigquery.py index 7eee8ac7c7..db73a0b746 100644 --- a/dlt/load/bigquery/bigquery.py +++ b/dlt/load/bigquery/bigquery.py @@ -199,7 +199,9 @@ def __init__(self, schema: Schema, config: BigQueryClientConfiguration) -> None: self.config: BigQueryClientConfiguration = config self.sql_client: BigQuerySqlClient = sql_client - def initialize_storage(self) -> None: + def initialize_storage(self, wipe_data: bool = False) -> None: + if wipe_data: + raise NotImplementedError() if not self.sql_client.has_dataset(): self.sql_client.create_dataset() diff --git a/dlt/load/bigquery/configuration.py b/dlt/load/bigquery/configuration.py index 865dd0db35..496d9b0f05 100644 --- a/dlt/load/bigquery/configuration.py +++ b/dlt/load/bigquery/configuration.py @@ -21,6 +21,5 @@ def check_integrity(self) -> None: # set the project id - it needs to be known by the client self.credentials.project_id = self.credentials.project_id or project_id except DefaultCredentialsError: - print("DefaultCredentialsError") # re-raise preventing exception raise self.credentials.__exception__ diff --git a/dlt/load/dummy/__init__.py b/dlt/load/dummy/__init__.py index d4906310b1..b29ba69807 100644 --- a/dlt/load/dummy/__init__.py +++ b/dlt/load/dummy/__init__.py @@ -10,7 +10,6 @@ @with_config(spec=DummyClientConfiguration, namespaces=("destination", "dummy",)) def _configure(config: DummyClientConfiguration = ConfigValue) -> DummyClientConfiguration: - print(dict(config)) return config diff --git a/dlt/load/dummy/dummy.py b/dlt/load/dummy/dummy.py index 1976b89b65..52d610b7a8 100644 --- a/dlt/load/dummy/dummy.py +++ b/dlt/load/dummy/dummy.py @@ -81,7 +81,7 @@ def __init__(self, schema: Schema, config: DummyClientConfiguration) -> None: super().__init__(schema, config) self.config: DummyClientConfiguration = config - def initialize_storage(self) -> None: + def initialize_storage(self, wipe_data: bool = False) -> None: pass def update_storage_schema(self) -> None: diff --git a/dlt/load/load.py b/dlt/load/load.py index 2846613f36..dbad1ce523 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -53,7 +53,7 @@ def create_storage(self, is_storage_owner: bool) -> LoadStorage: is_storage_owner, self.capabilities.preferred_loader_file_format, self.capabilities.supported_loader_file_formats, - config=self.config.load_storage_config + config=self.config._load_storage_config ) return load_storage diff --git a/dlt/load/redshift/redshift.py b/dlt/load/redshift/redshift.py index ad677baaac..7764460e35 100644 --- a/dlt/load/redshift/redshift.py +++ b/dlt/load/redshift/redshift.py @@ -213,7 +213,9 @@ def __init__(self, schema: Schema, config: RedshiftClientConfiguration) -> None: self.config: RedshiftClientConfiguration = config self.sql_client = sql_client - def initialize_storage(self) -> None: + def initialize_storage(self, wipe_data: bool = False) -> None: + if wipe_data: + raise NotImplementedError() if not self.sql_client.has_dataset(): self.sql_client.create_dataset() From 11a80d890917c0808df00cd47a5d0f79c9fdca8b Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Mon, 31 Oct 2022 20:30:04 +0100 Subject: [PATCH 66/66] various typing improvements --- dlt/common/typing.py | 6 +- dlt/helpers/streamlit.py | 340 +++++++++++++++++------------------ dlt/normalize/normalize.py | 6 +- dlt/pipeline/__init__.py | 3 +- dlt/pipeline/exceptions.py | 26 +-- dlt/pipeline/pipeline.py | 6 +- examples/google_drive_csv.py | 3 +- mypy.ini | 1 + poetry.lock | 277 +++++++--------------------- 
pyproject.toml | 9 +- 10 files changed, 251 insertions(+), 426 deletions(-) diff --git a/dlt/common/typing.py b/dlt/common/typing.py index 72d7847398..b29d8a495d 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -1,18 +1,18 @@ from collections.abc import Mapping as C_Mapping, Sequence as C_Sequence from re import Pattern as _REPattern from typing import Callable, Dict, Any, Literal, List, Mapping, NewType, Tuple, Type, TypeVar, Generic, Protocol, TYPE_CHECKING, Union, runtime_checkable, get_args, get_origin -from typing_extensions import TypeAlias +from typing_extensions import TypeAlias, ParamSpec if TYPE_CHECKING: from _typeshed import StrOrBytesPath - from typing_extensions import ParamSpec + # from typing_extensions import ParamSpec from typing import _TypedDict REPattern = _REPattern[str] else: StrOrBytesPath = Any from typing import _TypedDictMeta as _TypedDict REPattern = _REPattern - ParamSpec = lambda x: [x] + # ParamSpec = lambda x: [x] DictStrAny: TypeAlias = Dict[str, Any] DictStrStr: TypeAlias = Dict[str, str] diff --git a/dlt/helpers/streamlit.py b/dlt/helpers/streamlit.py index f063ff2273..383959d55f 100644 --- a/dlt/helpers/streamlit.py +++ b/dlt/helpers/streamlit.py @@ -1,170 +1,170 @@ -import os -import tomlkit -from tomlkit.container import Container as TomlContainer -from typing import cast -from copy import deepcopy - -from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration -from dlt.common.utils import dict_remove_nones_in_place - -from dlt.pipeline import Pipeline -from dlt.pipeline.typing import credentials_from_dict -from dlt.pipeline.exceptions import MissingDependencyException, PipelineException -from dlt.helpers.pandas import query_results_to_df, pd - -try: - import streamlit as st - from streamlit import SECRETS_FILE_LOC, secrets -except ImportError: - raise MissingDependencyException("DLT Streamlit Helpers", ["streamlit"], "DLT Helpers for Streamlit should be run within a streamlit app.") - - -def restore_pipeline() -> Pipeline: - """Restores Pipeline instance and associated credentials from Streamlit secrets - - Current implementation requires that pipeline working dir is available at the location saved in secrets. - - Raises: - PipelineBackupNotFound: Raised when pipeline backup is not available - CannotRestorePipelineException: Raised when pipeline working dir is not found or invalid - - Returns: - Pipeline: Instance of pipeline with attached credentials - """ - if "dlt" not in secrets: - raise PipelineException("You must backup pipeline to Streamlit first") - dlt_cfg = secrets["dlt"] - credentials = deepcopy(dict(dlt_cfg["destination"])) - if "default_schema_name" in credentials: - del credentials["default_schema_name"] - credentials.update(dlt_cfg["credentials"]) - pipeline = Pipeline(dlt_cfg["pipeline_name"]) - pipeline.restore_pipeline(credentials_from_dict(credentials), dlt_cfg["working_dir"]) - return pipeline - - -def backup_pipeline(pipeline: Pipeline) -> None: - """Backups pipeline state to the `secrets.toml` of the Streamlit app. - - Pipeline credentials and working directory will be added to the Streamlit `secrets` file. This allows to access query the data loaded to the destination and - access definitions of the inferred schemas. See `restore_pipeline` and `write_data_explorer_page` functions in the same module. 
- - Args: - pipeline (Pipeline): Pipeline instance, typically restored with `restore_pipeline` - """ - # save pipeline state to project .config - # config_file_name = file_util.get_project_streamlit_file_path("config.toml") - - # save credentials to secrets - if os.path.isfile(SECRETS_FILE_LOC): - with open(SECRETS_FILE_LOC, "r", encoding="utf-8") as f: - # use whitespace preserving parser - secrets_ = tomlkit.load(f) - else: - secrets_ = tomlkit.document() - - # save general settings - secrets_["dlt"] = { - "working_dir": pipeline.working_dir, - "pipeline_name": pipeline.pipeline_name - } - - # get client config - # TODO: pipeline api v2 should provide a direct method to get configurations - CONFIG: BaseConfiguration = pipeline._loader_instance.load_client_cls.CONFIG # type: ignore - CREDENTIALS: CredentialsConfiguration = pipeline._loader_instance.load_client_cls.CREDENTIALS # type: ignore - - # save client config - # print(dict_remove_nones_in_place(CONFIG.as_dict(lowercase=False))) - dlt_c = cast(TomlContainer, secrets_["dlt"]) - dlt_c["destination"] = dict_remove_nones_in_place(dict(CONFIG)) - dlt_c["credentials"] = dict_remove_nones_in_place(dict(CREDENTIALS)) - - with open(SECRETS_FILE_LOC, "w", encoding="utf-8") as f: - # use whitespace preserving parser - tomlkit.dump(secrets_, f) - - -def write_data_explorer_page(pipeline: Pipeline, schema_name: str = None, show_dlt_tables: bool = False, example_query: str = "", show_charts: bool = True) -> None: - """Writes Streamlit app page with a schema and live data preview. - - Args: - pipeline (Pipeline): Pipeline instance to use. - schema_name (str, optional): Name of the schema to display. If None, default schema is used. - show_dlt_tables (bool, optional): Should show DLT internal tables. Defaults to False. - example_query (str, optional): Example query to be displayed in the SQL Query box. - show_charts (bool, optional): Should automatically show charts for the queries from SQL Query box. Defaults to True. 
- - Raises: - MissingDependencyException: Raised when a particular python dependency is not installed - """ - @st.experimental_memo(ttl=600) - def run_query(query: str) -> pd.DataFrame: - # dlt pipeline exposes configured sql client that (among others) let's you make queries against the warehouse - with pipeline.sql_client(schema_name) as client: - df = query_results_to_df(client, query) - return df - - if schema_name: - schema = pipeline.get_schema(schema_name) - else: - schema = pipeline.get_default_schema() - st.title(f"Available tables in {schema.name} schema") - # st.text(schema.to_pretty_yaml()) - - for table in schema.all_tables(with_dlt_tables=show_dlt_tables): - table_name = table["name"] - st.header(table_name) - if "description" in table: - st.text(table["description"]) - if "parent" in table: - st.text("Parent table: " + table["parent"]) - - # table schema contains various hints (like clustering or partition options) that we do not want to show in basic view - essentials_f = lambda c: {k:v for k, v in c.items() if k in ["name", "data_type", "nullable"]} - - st.table(map(essentials_f, table["columns"].values())) - # add a button that when pressed will show the full content of a table - if st.button("SHOW DATA", key=table_name): - st.text(f"Full {table_name} table content") - st.dataframe(run_query(f"SELECT * FROM {table_name}")) - - st.title("Run your query") - sql_query = st.text_area("Enter your SQL query", value=example_query) - if st.button("Run Query"): - if sql_query: - st.text("Results of a query") - try: - # run the query from the text area - df = run_query(sql_query) - # and display the results - st.dataframe(df) - - try: - # now if the dataset has supported shape try to display the bar or altair chart - if df.dtypes.shape[0] == 1 and show_charts: - # try barchart - st.bar_chart(df) - if df.dtypes.shape[0] == 2 and show_charts: - - # try to import altair charts - try: - import altair as alt - except ImportError: - raise MissingDependencyException( - "DLT Streamlit Helpers", - ["altair"], - "DLT Helpers for Streamlit should be run within a streamlit app." - ) - - # try altair - bar_chart = alt.Chart(df).mark_bar().encode( - x=f'{df.columns[1]}:Q', - y=alt.Y(f'{df.columns[0]}:N', sort='-x') - ) - st.altair_chart(bar_chart, use_container_width=True) - except Exception as ex: - st.error(f"Chart failed due to: {ex}") - except Exception as ex: - st.text("Exception when running query") - st.exception(ex) +# import os +# import tomlkit +# from tomlkit.container import Container as TomlContainer +# from typing import cast +# from copy import deepcopy + +# from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration +# from dlt.common.utils import dict_remove_nones_in_place + +# from dlt.pipeline import Pipeline +# from dlt.pipeline.typing import credentials_from_dict +# from dlt.pipeline.exceptions import MissingDependencyException, PipelineException +# from dlt.helpers.pandas import query_results_to_df, pd + +# try: +# import streamlit as st +# from streamlit import SECRETS_FILE_LOC, secrets +# except ImportError: +# raise MissingDependencyException("DLT Streamlit Helpers", ["streamlit"], "DLT Helpers for Streamlit should be run within a streamlit app.") + + +# def restore_pipeline() -> Pipeline: +# """Restores Pipeline instance and associated credentials from Streamlit secrets + +# Current implementation requires that pipeline working dir is available at the location saved in secrets. 
+ +# Raises: +# PipelineBackupNotFound: Raised when pipeline backup is not available +# CannotRestorePipelineException: Raised when pipeline working dir is not found or invalid + +# Returns: +# Pipeline: Instance of pipeline with attached credentials +# """ +# if "dlt" not in secrets: +# raise PipelineException("You must backup pipeline to Streamlit first") +# dlt_cfg = secrets["dlt"] +# credentials = deepcopy(dict(dlt_cfg["destination"])) +# if "default_schema_name" in credentials: +# del credentials["default_schema_name"] +# credentials.update(dlt_cfg["credentials"]) +# pipeline = Pipeline(dlt_cfg["pipeline_name"]) +# pipeline.restore_pipeline(credentials_from_dict(credentials), dlt_cfg["working_dir"]) +# return pipeline + + +# def backup_pipeline(pipeline: Pipeline) -> None: +# """Backups pipeline state to the `secrets.toml` of the Streamlit app. + +# Pipeline credentials and working directory will be added to the Streamlit `secrets` file. This allows to access query the data loaded to the destination and +# access definitions of the inferred schemas. See `restore_pipeline` and `write_data_explorer_page` functions in the same module. + +# Args: +# pipeline (Pipeline): Pipeline instance, typically restored with `restore_pipeline` +# """ +# # save pipeline state to project .config +# # config_file_name = file_util.get_project_streamlit_file_path("config.toml") + +# # save credentials to secrets +# if os.path.isfile(SECRETS_FILE_LOC): +# with open(SECRETS_FILE_LOC, "r", encoding="utf-8") as f: +# # use whitespace preserving parser +# secrets_ = tomlkit.load(f) +# else: +# secrets_ = tomlkit.document() + +# # save general settings +# secrets_["dlt"] = { +# "working_dir": pipeline.working_dir, +# "pipeline_name": pipeline.pipeline_name +# } + +# # get client config +# # TODO: pipeline api v2 should provide a direct method to get configurations +# CONFIG: BaseConfiguration = pipeline._loader_instance.load_client_cls.CONFIG # type: ignore +# CREDENTIALS: CredentialsConfiguration = pipeline._loader_instance.load_client_cls.CREDENTIALS # type: ignore + +# # save client config +# # print(dict_remove_nones_in_place(CONFIG.as_dict(lowercase=False))) +# dlt_c = cast(TomlContainer, secrets_["dlt"]) +# dlt_c["destination"] = dict_remove_nones_in_place(dict(CONFIG)) +# dlt_c["credentials"] = dict_remove_nones_in_place(dict(CREDENTIALS)) + +# with open(SECRETS_FILE_LOC, "w", encoding="utf-8") as f: +# # use whitespace preserving parser +# tomlkit.dump(secrets_, f) + + +# def write_data_explorer_page(pipeline: Pipeline, schema_name: str = None, show_dlt_tables: bool = False, example_query: str = "", show_charts: bool = True) -> None: +# """Writes Streamlit app page with a schema and live data preview. + +# Args: +# pipeline (Pipeline): Pipeline instance to use. +# schema_name (str, optional): Name of the schema to display. If None, default schema is used. +# show_dlt_tables (bool, optional): Should show DLT internal tables. Defaults to False. +# example_query (str, optional): Example query to be displayed in the SQL Query box. +# show_charts (bool, optional): Should automatically show charts for the queries from SQL Query box. Defaults to True. 
+ +# Raises: +# MissingDependencyException: Raised when a particular python dependency is not installed +# """ +# @st.experimental_memo(ttl=600) +# def run_query(query: str) -> pd.DataFrame: +# # dlt pipeline exposes configured sql client that (among others) let's you make queries against the warehouse +# with pipeline.sql_client(schema_name) as client: +# df = query_results_to_df(client, query) +# return df + +# if schema_name: +# schema = pipeline.get_schema(schema_name) +# else: +# schema = pipeline.get_default_schema() +# st.title(f"Available tables in {schema.name} schema") +# # st.text(schema.to_pretty_yaml()) + +# for table in schema.all_tables(with_dlt_tables=show_dlt_tables): +# table_name = table["name"] +# st.header(table_name) +# if "description" in table: +# st.text(table["description"]) +# if "parent" in table: +# st.text("Parent table: " + table["parent"]) + +# # table schema contains various hints (like clustering or partition options) that we do not want to show in basic view +# essentials_f = lambda c: {k:v for k, v in c.items() if k in ["name", "data_type", "nullable"]} + +# st.table(map(essentials_f, table["columns"].values())) +# # add a button that when pressed will show the full content of a table +# if st.button("SHOW DATA", key=table_name): +# st.text(f"Full {table_name} table content") +# st.dataframe(run_query(f"SELECT * FROM {table_name}")) + +# st.title("Run your query") +# sql_query = st.text_area("Enter your SQL query", value=example_query) +# if st.button("Run Query"): +# if sql_query: +# st.text("Results of a query") +# try: +# # run the query from the text area +# df = run_query(sql_query) +# # and display the results +# st.dataframe(df) + +# try: +# # now if the dataset has supported shape try to display the bar or altair chart +# if df.dtypes.shape[0] == 1 and show_charts: +# # try barchart +# st.bar_chart(df) +# if df.dtypes.shape[0] == 2 and show_charts: + +# # try to import altair charts +# try: +# import altair as alt +# except ImportError: +# raise MissingDependencyException( +# "DLT Streamlit Helpers", +# ["altair"], +# "DLT Helpers for Streamlit should be run within a streamlit app." 
+# ) + +# # try altair +# bar_chart = alt.Chart(df).mark_bar().encode( +# x=f'{df.columns[1]}:Q', +# y=alt.Y(f'{df.columns[0]}:N', sort='-x') +# ) +# st.altair_chart(bar_chart, use_container_width=True) +# except Exception as ex: +# st.error(f"Chart failed due to: {ex}") +# except Exception as ex: +# st.text("Exception when running query") +# st.exception(ex) diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 4954343f02..4d3ee8a4ba 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -47,7 +47,7 @@ def __init__(self, collector: CollectorRegistry = REGISTRY, schema_storage: Sche # setup storages self.create_storages() # create schema storage with give type - self.schema_storage = schema_storage or SchemaStorage(self.config.schema_storage_config, makedirs=True) + self.schema_storage = schema_storage or SchemaStorage(self.config._schema_storage_config, makedirs=True) try: self.create_gauges(collector) except ValueError as v: @@ -64,9 +64,9 @@ def create_gauges(registry: CollectorRegistry) -> None: def create_storages(self) -> None: # pass initial normalize storage config embedded in normalize config - self.normalize_storage = NormalizeStorage(True, config=self.config.normalize_storage_config) + self.normalize_storage = NormalizeStorage(True, config=self.config._normalize_storage_config) # normalize saves in preferred format but can read all supported formats - self.load_storage = LoadStorage(True, self.loader_file_format, LoadStorage.ALL_SUPPORTED_FILE_FORMATS, config=self.config.load_storage_config) + self.load_storage = LoadStorage(True, self.loader_file_format, LoadStorage.ALL_SUPPORTED_FILE_FORMATS, config=self.config._load_storage_config) @staticmethod diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py index 043593fb60..ee30fa952f 100644 --- a/dlt/pipeline/__init__.py +++ b/dlt/pipeline/__init__.py @@ -30,11 +30,10 @@ def pipeline( if context.is_activated(): return cast(Pipeline, context.pipeline()) - print(kwargs["_last_dlt_config"].pipeline_name) - print(kwargs["_last_dlt_config"].runtime.log_level) # if working_dir not provided use temp folder if not working_dir: working_dir = get_default_working_dir() + destination = DestinationReference.from_name(destination) # create new pipeline instance p = Pipeline(pipeline_name, working_dir, pipeline_secret, destination, dataset_name, import_schema_path, export_schema_path, always_drop_pipeline, False, kwargs["runtime"]) diff --git a/dlt/pipeline/exceptions.py b/dlt/pipeline/exceptions.py index 690ed9023c..5655243214 100644 --- a/dlt/pipeline/exceptions.py +++ b/dlt/pipeline/exceptions.py @@ -27,32 +27,16 @@ def _get_msg(self, appendix: str) -> str: def _to_pip_install(self) -> str: return "\n".join([f"pip install {d}" for d in self.dependencies]) - -# class NoPipelineException(PipelineException): -# def __init__(self) -> None: -# super().__init__("Please create or restore pipeline before using this function") - - class PipelineConfigMissing(PipelineException): - def __init__(self, config_elem: str, step: TPipelineStep, help: str = None) -> None: + def __init__(self, config_elem: str, step: TPipelineStep, _help: str = None) -> None: self.config_elem = config_elem self.step = step msg = f"Configuration element {config_elem} was not provided and {step} step cannot be executed" - if help: - msg += f"\n{help}\n" + if _help: + msg += f"\n{_help}\n" super().__init__(msg) -# class PipelineConfiguredException(PipelineException): -# def __init__(self, f_name: str) -> None: -# 
super().__init__(f"{f_name} cannot be called on already configured or restored pipeline.") - - -# class InvalidPipelineContextException(PipelineException): -# def __init__(self) -> None: -# super().__init__("There may be just one active pipeline in single python process. To activate current pipeline call `activate` method") - - class CannotRestorePipelineException(PipelineException): def __init__(self, pipeline_name: str, working_dir: str, reason: str) -> None: msg = f"Pipeline with name {pipeline_name} in working directory {working_dir} could not be restored: {reason}" @@ -70,7 +54,3 @@ def __init__(self, step: TPipelineStep, exception: BaseException, run_metrics: T self.exception = exception self.run_metrics = run_metrics super().__init__(f"Pipeline execution failed at stage {step} with exception:\n\n{type(exception)}\n{exception}") - - -# class CannotApplyHintsToManyResources(ArgumentsOverloadException): -# pass diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 178cbbbc29..a633ebf49f 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -214,7 +214,7 @@ def item_to_source(data_item: Any) -> DltSource: if schema: data_item.schema = schema # try to apply hints to resources - resources = data_item.resources + resources = data_item.resources.values() for r in resources: apply_hint_args(r) return data_item @@ -385,7 +385,7 @@ def sql_client(self, schema_name: str = None) -> SqlClientBase[Any]: if isinstance(client, SqlJobClientBase): return client.sql_client else: - raise SqlClientNotAvailable(self.destination.name()) + raise SqlClientNotAvailable(self.destination.__name__) def _get_normalize_storage(self) -> NormalizeStorage: return NormalizeStorage(True, self._normalize_storage_config) @@ -567,8 +567,6 @@ def _managed_state(self) -> Iterator[TPipelineState]: # load state from storage to be merged with pipeline changes, currently we assume no parallel changes # compare backup and new state, save only if different backup_state = self._get_state() - print(state) - print(backup_state) new_state = json.dumps(state, sort_keys=True) old_state = json.dumps(backup_state, sort_keys=True) # persist old state diff --git a/examples/google_drive_csv.py b/examples/google_drive_csv.py index 2b2f120e8e..36f129ea7a 100644 --- a/examples/google_drive_csv.py +++ b/examples/google_drive_csv.py @@ -62,8 +62,7 @@ def download_csv_as_json(file_id: str, csv_options: StrAny = None) -> Iterator[D # SCHEMA CREATION data_schema = None - # data_schema_file_path = f"/Users/adrian/PycharmProjects/sv/dlt/examples/schemas/inferred_drive_csv_{file_id}_schema.yml" - data_schema_file_path = f"examples/schemas/inferred_drive_csv_{file_id}_schema.yml" + data_schema_file_path = f"examples/schemas/inferred_drive_csv_{file_id}.schema.yml" credentials = GCPPipelineCredentials.from_services_file(gcp_credential_json_file_path, schema_prefix) diff --git a/mypy.ini b/mypy.ini index 33f3f0824c..f3eebe74e5 100644 --- a/mypy.ini +++ b/mypy.ini @@ -10,6 +10,7 @@ check_untyped_defs=true warn_return_any=true namespace_packages=true warn_unused_ignores=true +enable_incomplete_features=true ;disallow_any_generics=false diff --git a/poetry.lock b/poetry.lock index 51dfa3b3f7..08de21b649 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3,7 +3,7 @@ name = "agate" version = "1.6.3" description = "A data analysis library that is optimized for humans instead of machines." 
category = "main" -optional = true +optional = false python-versions = "*" [package.dependencies] @@ -84,7 +84,7 @@ name = "babel" version = "2.10.3" description = "Internationalization utilities" category = "main" -optional = true +optional = false python-versions = ">=3.6" [package.dependencies] @@ -109,38 +109,6 @@ test = ["coverage (>=4.5.4)", "fixtures (>=3.0.0)", "flake8 (>=4.0.0)", "stestr toml = ["toml"] yaml = ["pyyaml"] -[[package]] -name = "boto3" -version = "1.24.76" -description = "The AWS SDK for Python" -category = "main" -optional = true -python-versions = ">= 3.7" - -[package.dependencies] -botocore = ">=1.27.76,<1.28.0" -jmespath = ">=0.7.1,<2.0.0" -s3transfer = ">=0.6.0,<0.7.0" - -[package.extras] -crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] - -[[package]] -name = "botocore" -version = "1.27.76" -description = "Low-level, data-driven core of boto 3." -category = "main" -optional = true -python-versions = ">= 3.7" - -[package.dependencies] -jmespath = ">=0.7.1,<2.0.0" -python-dateutil = ">=2.1,<3.0.0" -urllib3 = ">=1.25.4,<1.27" - -[package.extras] -crt = ["awscrt (==0.14.0)"] - [[package]] name = "cachetools" version = "5.2.0" @@ -162,7 +130,7 @@ name = "cffi" version = "1.15.1" description = "Foreign Function Interface for Python calling C code." category = "main" -optional = true +optional = false python-versions = "*" [package.dependencies] @@ -184,7 +152,7 @@ name = "click" version = "8.1.3" description = "Composable command line interface toolkit" category = "main" -optional = true +optional = false python-versions = ">=3.7" [package.dependencies] @@ -198,29 +166,13 @@ category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -[[package]] -name = "dbt-bigquery" -version = "1.0.0" -description = "The BigQuery adapter plugin for dbt" -category = "main" -optional = true -python-versions = ">=3.7" - -[package.dependencies] -dbt-core = ">=1.0.0,<1.1.0" -google-api-core = ">=1.16.0,<3" -google-cloud-bigquery = ">=1.25.0,<3" -google-cloud-core = ">=1.3.0,<3" -googleapis-common-protos = ">=1.6.0,<2" -protobuf = ">=3.13.0,<4" - [[package]] name = "dbt-core" -version = "1.0.6" +version = "1.1.2" description = "With dbt, data analysts and engineers can build analytics the way engineers build applications." category = "main" -optional = true -python-versions = ">=3.7" +optional = false +python-versions = ">=3.7.2" [package.dependencies] agate = ">=1.6,<1.6.4" @@ -228,7 +180,7 @@ cffi = ">=1.9,<2.0.0" click = ">=7.0,<9" colorama = ">=0.3.9,<0.4.5" dbt-extractor = ">=0.4.1,<0.5.0" -hologram = "0.0.14" +hologram = ">=0.0.14,<=0.0.15" idna = ">=2.5,<4" isodate = ">=0.6,<0.7" Jinja2 = "2.11.3" @@ -236,11 +188,11 @@ logbook = ">=1.5,<1.6" MarkupSafe = ">=0.23,<2.1" mashumaro = "2.9" minimal-snowplow-tracker = "0.0.2" -networkx = ">=2.3,<3" +networkx = ">=2.3,<2.8.4" packaging = ">=20.9,<22.0" requests = "<3.0.0" sqlparse = ">=0.2.3,<0.5" -typing-extensions = ">=3.7.4,<3.11" +typing-extensions = ">=3.7.4" werkzeug = ">=1,<3" [[package]] @@ -248,34 +200,9 @@ name = "dbt-extractor" version = "0.4.1" description = "A tool to analyze and extract information from Jinja used in dbt projects." 
category = "main" -optional = true +optional = false python-versions = ">=3.6.1" -[[package]] -name = "dbt-postgres" -version = "1.0.6" -description = "The postgres adpter plugin for dbt (data build tool)" -category = "main" -optional = true -python-versions = ">=3.7" - -[package.dependencies] -dbt-core = "1.0.6" -psycopg2-binary = ">=2.8,<3.0" - -[[package]] -name = "dbt-redshift" -version = "1.0.1" -description = "The Redshift adapter plugin for dbt" -category = "main" -optional = true -python-versions = ">=3.7" - -[package.dependencies] -boto3 = ">=1.4.4,<2.0.0" -dbt-core = ">=1.0.0,<1.1.0" -dbt-postgres = ">=1.0.0,<1.1.0" - [[package]] name = "decopatch" version = "1.4.10" @@ -391,7 +318,7 @@ name = "future" version = "0.18.2" description = "Clean single-source support for Python 3 and 2" category = "main" -optional = true +optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" [[package]] @@ -602,14 +529,14 @@ dev = ["jinja2 (>=3.0.0,<3.1.0)", "towncrier (>=21,<22)", "sphinx-rtd-theme (>=0 [[package]] name = "hologram" -version = "0.0.14" +version = "0.0.15" description = "JSON schema generation from dataclasses" category = "main" -optional = true +optional = false python-versions = "*" [package.dependencies] -jsonschema = ">=3.0,<3.2" +jsonschema = ">=3.0,<4.0" python-dateutil = ">=2.8,<2.9" [[package]] @@ -624,7 +551,7 @@ python-versions = ">=3.5" name = "importlib-metadata" version = "4.12.0" description = "Read metadata from Python packages" -category = "main" +category = "dev" optional = false python-versions = ">=3.7" @@ -649,7 +576,7 @@ name = "isodate" version = "0.6.1" description = "An ISO 8601 date/time/duration parser and formatter" category = "main" -optional = true +optional = false python-versions = "*" [package.dependencies] @@ -660,7 +587,7 @@ name = "jinja2" version = "2.11.3" description = "A very fast and expressive template engine." category = "main" -optional = true +optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [package.dependencies] @@ -669,14 +596,6 @@ MarkupSafe = ">=0.23" [package.extras] i18n = ["Babel (>=0.8)"] -[[package]] -name = "jmespath" -version = "1.0.1" -description = "JSON Matching Expressions" -category = "main" -optional = true -python-versions = ">=3.7" - [[package]] name = "json-logging" version = "1.4.1rc0" @@ -695,27 +614,27 @@ python-versions = ">=3.6" [[package]] name = "jsonschema" -version = "3.1.1" +version = "3.2.0" description = "An implementation of JSON Schema validation for Python" category = "main" -optional = true +optional = false python-versions = "*" [package.dependencies] attrs = ">=17.4.0" -importlib-metadata = "*" pyrsistent = ">=0.14.0" six = ">=1.11.0" [package.extras] format = ["idna", "jsonpointer (>1.13)", "rfc3987", "strict-rfc3339", "webcolors"] +format_nongpl = ["idna", "jsonpointer (>1.13)", "webcolors", "rfc3986-validator (>0.1.0)", "rfc3339-validator"] [[package]] name = "leather" version = "0.3.4" description = "Python charting for 80% of humans." category = "main" -optional = true +optional = false python-versions = "*" [package.dependencies] @@ -726,7 +645,7 @@ name = "logbook" version = "1.5.3" description = "A logging replacement for Python" category = "main" -optional = true +optional = false python-versions = "*" [package.extras] @@ -753,7 +672,7 @@ name = "markupsafe" version = "2.0.1" description = "Safely add untrusted strings to HTML/XML markup." 
category = "main" -optional = true +optional = false python-versions = ">=3.6" [[package]] @@ -761,7 +680,7 @@ name = "mashumaro" version = "2.9" description = "Fast serialization framework on top of dataclasses" category = "main" -optional = true +optional = false python-versions = ">=3.6" [package.dependencies] @@ -782,7 +701,7 @@ name = "minimal-snowplow-tracker" version = "0.0.2" description = "A minimal snowplow event tracker for Python. Add analytics to your Python and Django apps, webapps and games" category = "main" -optional = true +optional = false python-versions = "*" [package.dependencies] @@ -794,16 +713,16 @@ name = "msgpack" version = "1.0.4" description = "MessagePack serializer" category = "main" -optional = true +optional = false python-versions = "*" [[package]] name = "mypy" -version = "0.971" +version = "0.982" description = "Optional static typing for Python" category = "dev" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.dependencies] mypy-extensions = ">=0.4.3" @@ -811,9 +730,9 @@ tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} typing-extensions = ">=3.10" [package.extras] -reports = ["lxml"] -python2 = ["typed-ast (>=1.4.0,<2)"] dmypy = ["psutil (>=4.0)"] +python2 = ["typed-ast (>=1.4.0,<2)"] +reports = ["lxml"] [[package]] name = "mypy-extensions" @@ -837,16 +756,16 @@ icu = ["PyICU (>=1.0.0)"] [[package]] name = "networkx" -version = "2.8.6" +version = "2.8.3" description = "Python package for creating and manipulating graphs and networks" category = "main" -optional = true +optional = false python-versions = ">=3.8" [package.extras] default = ["numpy (>=1.19)", "scipy (>=1.8)", "matplotlib (>=3.4)", "pandas (>=1.3)"] -developer = ["pre-commit (>=2.20)", "mypy (>=0.961)"] -doc = ["sphinx (>=5)", "pydata-sphinx-theme (>=0.9)", "sphinx-gallery (>=0.10)", "numpydoc (>=1.4)", "pillow (>=9.1)", "nb2plots (>=0.6)", "texext (>=0.6.6)"] +developer = ["pre-commit (>=2.19)", "mypy (>=0.960)"] +doc = ["sphinx (>=4.5)", "pydata-sphinx-theme (>=0.8.1)", "sphinx-gallery (>=0.10)", "numpydoc (>=1.3)", "pillow (>=9.1)", "nb2plots (>=0.6)", "texext (>=0.6.6)"] extra = ["lxml (>=4.6)", "pygraphviz (>=1.9)", "pydot (>=1.4.2)", "sympy (>=1.10)"] test = ["pytest (>=7.1)", "pytest-cov (>=3.0)", "codecov (>=2.1)"] @@ -874,7 +793,7 @@ name = "parsedatetime" version = "2.4" description = "Parse human-readable date/time text." 
category = "main" -optional = true +optional = false python-versions = "*" [package.dependencies] @@ -1027,7 +946,7 @@ name = "pycparser" version = "2.21" description = "C parser in Python" category = "main" -optional = true +optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" [[package]] @@ -1051,10 +970,10 @@ diagrams = ["railroad-diagrams", "jinja2"] [[package]] name = "pyrsistent" -version = "0.18.1" +version = "0.19.1" description = "Persistent/Functional/Immutable data structures" category = "main" -optional = true +optional = false python-versions = ">=3.7" [[package]] @@ -1143,7 +1062,7 @@ name = "python-slugify" version = "6.1.2" description = "A Python slugify application that also handles Unicode" category = "main" -optional = true +optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" [package.dependencies] @@ -1157,15 +1076,15 @@ name = "pytimeparse" version = "1.1.8" description = "Time expression parser" category = "main" -optional = true +optional = false python-versions = "*" [[package]] name = "pytz" -version = "2022.2.1" +version = "2022.5" description = "World timezone definitions, modern and historical" category = "main" -optional = true +optional = false python-versions = "*" [[package]] @@ -1213,20 +1132,6 @@ python-versions = ">=3.6,<4" [package.dependencies] pyasn1 = ">=0.1.3" -[[package]] -name = "s3transfer" -version = "0.6.0" -description = "An Amazon S3 Transfer Manager" -category = "main" -optional = true -python-versions = ">= 3.7" - -[package.dependencies] -botocore = ">=1.12.36,<2.0a.0" - -[package.extras] -crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"] - [[package]] name = "semver" version = "2.13.0" @@ -1293,10 +1198,10 @@ python-versions = ">=3.6" [[package]] name = "sqlparse" -version = "0.4.2" +version = "0.4.3" description = "A non-validating SQL parser." category = "main" -optional = true +optional = false python-versions = ">=3.5" [[package]] @@ -1315,7 +1220,7 @@ name = "text-unidecode" version = "1.3" description = "The most basic Text::Unidecode port" category = "main" -optional = true +optional = false python-versions = "*" [[package]] @@ -1403,11 +1308,11 @@ python-versions = "*" [[package]] name = "typing-extensions" -version = "3.10.0.2" -description = "Backported and Experimental Type Hints for Python 3.5+" +version = "4.4.0" +description = "Backported and Experimental Type Hints for Python 3.7+" category = "main" optional = false -python-versions = "*" +python-versions = ">=3.7" [[package]] name = "tzdata" @@ -1435,7 +1340,7 @@ name = "werkzeug" version = "2.1.2" description = "The comprehensive WSGI web application library." 
category = "main" -optional = true +optional = false python-versions = ">=3.7" [package.extras] @@ -1445,7 +1350,7 @@ watchdog = ["watchdog"] name = "zipp" version = "3.8.1" description = "Backport of pathlib-compatible object wrapper for zip files" -category = "main" +category = "dev" optional = false python-versions = ">=3.7" @@ -1455,7 +1360,6 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [extras] bigquery = ["grpcio", "google-cloud-bigquery", "google-cloud-bigquery-storage", "pyarrow"] -dbt = ["dbt-core", "GitPython", "dbt-redshift", "dbt-bigquery"] gcp = ["grpcio", "google-cloud-bigquery", "google-cloud-bigquery-storage", "pyarrow"] postgres = ["psycopg2-binary", "psycopg2cffi"] redshift = ["psycopg2-binary", "psycopg2cffi"] @@ -1463,7 +1367,7 @@ redshift = ["psycopg2-binary", "psycopg2cffi"] [metadata] lock-version = "1.1" python-versions = "^3.8,<3.11" -content-hash = "24bd34ae0bdd70f265ba4fb25f28b084fa55f6f8c9563f482fd79c1d0d695563" +content-hash = "f3ce0afb16174d4f0b4e297adba698c13078f3f3cfee6526b776b8096720c33b" [metadata.files] agate = [ @@ -1486,8 +1390,6 @@ bandit = [ {file = "bandit-1.7.4-py3-none-any.whl", hash = "sha256:412d3f259dab4077d0e7f0c11f50f650cc7d10db905d98f6520a95a18049658a"}, {file = "bandit-1.7.4.tar.gz", hash = "sha256:2d63a8c573417bae338962d4b9b06fbc6080f74ecd955a092849e1e65c717bd2"}, ] -boto3 = [] -botocore = [] cachetools = [ {file = "cachetools-5.2.0-py3-none-any.whl", hash = "sha256:f9f17d2aec496a9aa6b76f53e3b614c965223c061982d434d160f930c698a9db"}, {file = "cachetools-5.2.0.tar.gz", hash = "sha256:6a94c6402995a99c3970cc7e4884bb60b4a8639938157eeed436098bf9831757"}, @@ -1568,14 +1470,7 @@ colorama = [ {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, ] -dbt-bigquery = [ - {file = "dbt-bigquery-1.0.0.tar.gz", hash = "sha256:e22442f00fcec155dcbfe8be351a11c35913fb6edd11bd5e52fafc3218abd12e"}, - {file = "dbt_bigquery-1.0.0-py3-none-any.whl", hash = "sha256:48778c89a37dd866ffd3718bf6b78e1139b7fb4cc0377f2feaa95e10dc3ce9c2"}, -] -dbt-core = [ - {file = "dbt-core-1.0.6.tar.gz", hash = "sha256:5155bc4e81aba9df1a9a183205c0a240a3ec08d4fb9377df4f0d4d4b96268be1"}, - {file = "dbt_core-1.0.6-py3-none-any.whl", hash = "sha256:20e8e4fdd9ad08a25b3fb7020ffbdfd3b9aa6339a63a3d125f3f6d3edc2605f2"}, -] +dbt-core = [] dbt-extractor = [ {file = "dbt_extractor-0.4.1-cp36-abi3-macosx_10_7_x86_64.whl", hash = "sha256:4dc715bd740e418d8dc1dd418fea508e79208a24cf5ab110b0092a3cbe96bf71"}, {file = "dbt_extractor-0.4.1-cp36-abi3-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:bc9e0050e3a2f4ea9fe58e8794bc808e6709a0c688ed710fc7c5b6ef3e5623ec"}, @@ -1594,14 +1489,6 @@ dbt-extractor = [ {file = "dbt_extractor-0.4.1-cp36-abi3-win_amd64.whl", hash = "sha256:35265a0ae0a250623b0c2e3308b2738dc8212e40e0aa88407849e9ea090bb312"}, {file = "dbt_extractor-0.4.1.tar.gz", hash = "sha256:75b1c665699ec0f1ffce1ba3d776f7dfce802156f22e70a7b9c8f0b4d7e80f42"}, ] -dbt-postgres = [ - {file = "dbt-postgres-1.0.6.tar.gz", hash = "sha256:f560ab7178e19990b9d1e5d4787a9f5c7104708a0bf09b8693548723b1d9dfc2"}, - {file = "dbt_postgres-1.0.6-py3-none-any.whl", hash = "sha256:3cf9d76d87768f7e398c86ade6c5be7fa1a3984384beb3a63a7c0b2008e6aec8"}, -] -dbt-redshift = [ - {file = "dbt-redshift-1.0.1.tar.gz", hash = 
"sha256:1e45d2948313a588d54d7b59354e7850a969cf2aafb4d3581f3a733cb0170e68"}, - {file = "dbt_redshift-1.0.1-py3-none-any.whl", hash = "sha256:1e5219d67c6c7a52235c46c7ca559b118ac7a5e1e62e6b3138eaa1cb67597751"}, -] decopatch = [ {file = "decopatch-1.4.10-py2.py3-none-any.whl", hash = "sha256:e151f7f93de2b1b3fd3f3272dcc7cefd1a69f68ec1c2d8e288ecd9deb36dc5f7"}, {file = "decopatch-1.4.10.tar.gz", hash = "sha256:957f49c93f4150182c23f8fb51d13bb3213e0f17a79e09c8cca7057598b55720"}, @@ -1701,10 +1588,7 @@ grpcio-status = [ {file = "grpcio_status-1.43.0-py3-none-any.whl", hash = "sha256:9036b24f5769adafdc3e91d9434c20e9ede0b30f50cc6bff105c0f414bb9e0e0"}, ] hexbytes = [] -hologram = [ - {file = "hologram-0.0.14-py3-none-any.whl", hash = "sha256:2911b59115bebd0504eb089532e494fa22ac704989afe41371c5361780433bfe"}, - {file = "hologram-0.0.14.tar.gz", hash = "sha256:fd67bd069e4681e1d2a447df976c65060d7a90fee7f6b84d133fd9958db074ec"}, -] +hologram = [] idna = [] importlib-metadata = [ {file = "importlib_metadata-4.12.0-py3-none-any.whl", hash = "sha256:7401a975809ea1fdc658c3aa4f78cc2195a0e019c5cbc4c06122884e9ae80c23"}, @@ -1722,10 +1606,6 @@ jinja2 = [ {file = "Jinja2-2.11.3-py2.py3-none-any.whl", hash = "sha256:03e47ad063331dd6a3f04a43eddca8a966a26ba0c5b7207a9a9e4e08f1b29419"}, {file = "Jinja2-2.11.3.tar.gz", hash = "sha256:a6d58433de0ae800347cab1fa3043cebbabe8baa9d29e668f1c768cb87a333c6"}, ] -jmespath = [ - {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, - {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, -] json-logging = [ {file = "json-logging-1.4.1rc0.tar.gz", hash = "sha256:381e00495bbd619d09c8c3d1fdd72c843f7045797ab63b42cfec5f7961e5b3f6"}, {file = "json_logging-1.4.1rc0-py2.py3-none-any.whl", hash = "sha256:2b787c28f31fb4d8aabac16ac3816326031d92dd054bdabc9bbe68eb10864f77"}, @@ -1734,10 +1614,7 @@ jsonlines = [ {file = "jsonlines-2.0.0-py3-none-any.whl", hash = "sha256:bfb043d4e25fd894dca67b1f2adf014e493cb65d0f18b3a74a98bfcd97c3d983"}, {file = "jsonlines-2.0.0.tar.gz", hash = "sha256:6fdd03104c9a421a1ba587a121aaac743bf02d8f87fa9cdaa3b852249a241fe8"}, ] -jsonschema = [ - {file = "jsonschema-3.1.1-py2.py3-none-any.whl", hash = "sha256:94c0a13b4a0616458b42529091624e66700a17f847453e52279e35509a5b7631"}, - {file = "jsonschema-3.1.1.tar.gz", hash = "sha256:2fa0684276b6333ff3c0b1b27081f4b2305f0a36cf702a23db50edb141893c3f"}, -] +jsonschema = [] leather = [ {file = "leather-0.3.4-py2.py3-none-any.whl", hash = "sha256:5e741daee96e9f1e9e06081b8c8a10c4ac199301a0564cdd99b09df15b4603d2"}, {file = "leather-0.3.4.tar.gz", hash = "sha256:b43e21c8fa46b2679de8449f4d953c06418666dc058ce41055ee8a8d3bb40918"}, @@ -1893,7 +1770,10 @@ mypy-extensions = [ {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, ] natsort = [] -networkx = [] +networkx = [ + {file = "networkx-2.8.3-py3-none-any.whl", hash = "sha256:f151edac6f9b0cf11fecce93e236ac22b499bb9ff8d6f8393b9fef5ad09506cc"}, + {file = "networkx-2.8.3.tar.gz", hash = "sha256:67fab04a955a73eb660fe7bf281b6fa71a003bc6e23a92d2f6227654c5223dbe"}, +] numpy = [] packaging = [ {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"}, @@ -2044,29 +1924,7 @@ pyparsing = [ {file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"}, {file = 
"pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"}, ] -pyrsistent = [ - {file = "pyrsistent-0.18.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:df46c854f490f81210870e509818b729db4488e1f30f2a1ce1698b2295a878d1"}, - {file = "pyrsistent-0.18.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d45866ececf4a5fff8742c25722da6d4c9e180daa7b405dc0a2a2790d668c26"}, - {file = "pyrsistent-0.18.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4ed6784ceac462a7d6fcb7e9b663e93b9a6fb373b7f43594f9ff68875788e01e"}, - {file = "pyrsistent-0.18.1-cp310-cp310-win32.whl", hash = "sha256:e4f3149fd5eb9b285d6bfb54d2e5173f6a116fe19172686797c056672689daf6"}, - {file = "pyrsistent-0.18.1-cp310-cp310-win_amd64.whl", hash = "sha256:636ce2dc235046ccd3d8c56a7ad54e99d5c1cd0ef07d9ae847306c91d11b5fec"}, - {file = "pyrsistent-0.18.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e92a52c166426efbe0d1ec1332ee9119b6d32fc1f0bbfd55d5c1088070e7fc1b"}, - {file = "pyrsistent-0.18.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7a096646eab884bf8bed965bad63ea327e0d0c38989fc83c5ea7b8a87037bfc"}, - {file = "pyrsistent-0.18.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cdfd2c361b8a8e5d9499b9082b501c452ade8bbf42aef97ea04854f4a3f43b22"}, - {file = "pyrsistent-0.18.1-cp37-cp37m-win32.whl", hash = "sha256:7ec335fc998faa4febe75cc5268a9eac0478b3f681602c1f27befaf2a1abe1d8"}, - {file = "pyrsistent-0.18.1-cp37-cp37m-win_amd64.whl", hash = "sha256:6455fc599df93d1f60e1c5c4fe471499f08d190d57eca040c0ea182301321286"}, - {file = "pyrsistent-0.18.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:fd8da6d0124efa2f67d86fa70c851022f87c98e205f0594e1fae044e7119a5a6"}, - {file = "pyrsistent-0.18.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7bfe2388663fd18bd8ce7db2c91c7400bf3e1a9e8bd7d63bf7e77d39051b85ec"}, - {file = "pyrsistent-0.18.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0e3e1fcc45199df76053026a51cc59ab2ea3fc7c094c6627e93b7b44cdae2c8c"}, - {file = "pyrsistent-0.18.1-cp38-cp38-win32.whl", hash = "sha256:b568f35ad53a7b07ed9b1b2bae09eb15cdd671a5ba5d2c66caee40dbf91c68ca"}, - {file = "pyrsistent-0.18.1-cp38-cp38-win_amd64.whl", hash = "sha256:d1b96547410f76078eaf66d282ddca2e4baae8964364abb4f4dcdde855cd123a"}, - {file = "pyrsistent-0.18.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f87cc2863ef33c709e237d4b5f4502a62a00fab450c9e020892e8e2ede5847f5"}, - {file = "pyrsistent-0.18.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6bc66318fb7ee012071b2792024564973ecc80e9522842eb4e17743604b5e045"}, - {file = "pyrsistent-0.18.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:914474c9f1d93080338ace89cb2acee74f4f666fb0424896fcfb8d86058bf17c"}, - {file = "pyrsistent-0.18.1-cp39-cp39-win32.whl", hash = "sha256:1b34eedd6812bf4d33814fca1b66005805d3640ce53140ab8bbb1e2651b0d9bc"}, - {file = "pyrsistent-0.18.1-cp39-cp39-win_amd64.whl", hash = "sha256:e24a828f57e0c337c8d8bb9f6b12f09dfdf0273da25fda9e314f0b684b415a07"}, - {file = "pyrsistent-0.18.1.tar.gz", hash = "sha256:d4d61f8b993a7255ba714df3aca52700f8125289f84f704cf80916517c46eb96"}, -] +pyrsistent = [] pytest = [ {file = "pytest-6.2.5-py3-none-any.whl", hash = 
"sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"}, {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"}, @@ -2140,10 +1998,6 @@ requests = [ {file = "requests-2.28.1.tar.gz", hash = "sha256:7c5599b102feddaa661c826c56ab4fee28bfd17f5abca1ebbe3e7f19d7c97983"}, ] rsa = [] -s3transfer = [ - {file = "s3transfer-0.6.0-py3-none-any.whl", hash = "sha256:06176b74f3a15f61f1b4f25a1fc29a4429040b7647133a463da8fa5bd28d5ecd"}, - {file = "s3transfer-0.6.0.tar.gz", hash = "sha256:2ed07d3866f523cc561bf4a00fc5535827981b117dd7876f036b0c1aca42c947"}, -] semver = [ {file = "semver-2.13.0-py2.py3-none-any.whl", hash = "sha256:ced8b23dceb22134307c1b8abfa523da14198793d9787ac838e70e29e77458d4"}, {file = "semver-2.13.0.tar.gz", hash = "sha256:fa0fe2722ee1c3f57eac478820c3a5ae2f624af8264cbdf9000c980ff7f75e3f"}, @@ -2220,10 +2074,7 @@ smmap = [ {file = "smmap-5.0.0-py3-none-any.whl", hash = "sha256:2aba19d6a040e78d8b09de5c57e96207b09ed71d8e55ce0959eeee6c8e190d94"}, {file = "smmap-5.0.0.tar.gz", hash = "sha256:c840e62059cd3be204b0c9c9f74be2c09d5648eddd4580d9314c3ecde0b30936"}, ] -sqlparse = [ - {file = "sqlparse-0.4.2-py3-none-any.whl", hash = "sha256:48719e356bb8b42991bdbb1e8b83223757b93789c00910a616a071910ca4a64d"}, - {file = "sqlparse-0.4.2.tar.gz", hash = "sha256:0c00730c74263a94e5a9919ade150dfc3b19c574389985446148402998287dae"}, -] +sqlparse = [] stevedore = [] text-unidecode = [ {file = "text-unidecode-1.3.tar.gz", hash = "sha256:bad6603bb14d279193107714b288be206cac565dfa49aa5b105294dd5c4aab93"}, @@ -2248,11 +2099,7 @@ types-pyyaml = [] types-requests = [] types-simplejson = [] types-urllib3 = [] -typing-extensions = [ - {file = "typing_extensions-3.10.0.2-py2-none-any.whl", hash = "sha256:d8226d10bc02a29bcc81df19a26e56a9647f8b0a6d4a83924139f4a8b01f17b7"}, - {file = "typing_extensions-3.10.0.2-py3-none-any.whl", hash = "sha256:f1d25edafde516b146ecd0613dabcc61409817af4766fbbcfb8d1ad4ec441a34"}, - {file = "typing_extensions-3.10.0.2.tar.gz", hash = "sha256:49f75d16ff11f1cd258e1b988ccff82a3ca5570217d7ad8c5f48205dd99a677e"}, -] +typing-extensions = [] tzdata = [] urllib3 = [] werkzeug = [ diff --git a/pyproject.toml b/pyproject.toml index 5bdb5c6726..467552a128 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,9 +47,9 @@ google-cloud-bigquery-storage = {version = "^2.13.0", optional = true} pyarrow = {version = "^8.0.0", optional = true} GitPython = {version = "^3.1.26", optional = true} -dbt-core = {version = "1.0.6", optional = true} -dbt-redshift = {version = "1.0.1", optional = true} -dbt-bigquery = {version = "1.0.0", optional = true} +dbt-core = {version = ">=1.1.0,<1.2.0", optional = true} +dbt-redshift = {version = ">=1.0.0,<1.2.0", optional = true} +dbt-bigquery = {version = ">=1.0.0,<1.2.0", optional = true} tzdata = "^2022.1" tomlkit = "^0.11.3" asyncstdlib = "^3.10.5" @@ -58,7 +58,7 @@ pathvalidate = "^2.5.2" [tool.poetry.dev-dependencies] pytest = "^6.2.4" -mypy = "0.971" +mypy = "0.982" flake8 = "^5.0.0" bandit = "^1.7.0" flake8-bugbear = "^22.0.0" @@ -75,6 +75,7 @@ types-python-dateutil = "^2.8.15" flake8-tidy-imports = "^4.8.0" flake8-encodings = "^0.5.0" flake8-builtins = "^1.5.3" +typing-extensions = "^4.4.0" [tool.poetry.extras] dbt = ["dbt-core", "GitPython", "dbt-redshift", "dbt-bigquery"]