Skip to content

Commit

Permalink
Validate DB and table names
Browse files Browse the repository at this point in the history
  • Loading branch information
jond01 committed Sep 12, 2024
1 parent 667e64a commit dddcfaf
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 1 deletion.
4 changes: 4 additions & 0 deletions storey/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ class TDEngineTypeError(TypeError):
pass


class TDEngineValueError(ValueError):
    """A ``ValueError`` specific to TDEngine, e.g. an invalid database or table name."""


class WindowBase:
def __init__(self, window, period, window_str):
self.window_millis = window
Expand Down
28 changes: 27 additions & 1 deletion storey/targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import os
import queue
import random
import re
import traceback
import uuid
from io import StringIO
Expand All @@ -33,7 +34,13 @@
import xxhash

from . import Driver
from .dtypes import Event, TDEngineTypeError, V3ioError, _TDEngineField
from .dtypes import (
Event,
TDEngineTypeError,
TDEngineValueError,
V3ioError,
_TDEngineField,
)
from .flow import Flow, _Batching, _split_path, _termination_obj
from .table import Table, _PersistJob
from .utils import stringify_key, url_to_file_system, wrap_event_for_serialization
Expand Down Expand Up @@ -808,6 +815,10 @@ class TDEngineTarget(_Batching, _Writer):
:type flush_after_seconds: int
"""

# https://docs.tdengine.com/reference/taos-sql/limit/
_DB_NAME_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]{0,63}$")
_TABLE_NAME_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]{0,191}$")

def __init__(
self,
url: str,
Expand Down Expand Up @@ -879,9 +890,24 @@ def __init__(
self._user = user
self._password = password
self._database = database
self._validate_db_and_table_names()
self._tdengine_type_to_column_func = self._get_tdengine_type_to_column_func()
self._tdengine_type_to_tag_func = self._get_tdengine_type_to_tag_func()

def _validate_db_and_table_names(self) -> None:
    """Validate the configured database, table and supertable names.

    The database name is mandatory and must fully match ``_DB_NAME_PATTERN``.
    The table and supertable names are optional; each one that is set must
    fully match ``_TABLE_NAME_PATTERN``.

    :raises TDEngineValueError: If the database name is missing/empty, or if
        any of the checked names does not match its pattern.
    """
    db_name = self._database
    if not db_name:
        raise TDEngineValueError("TDEngine database must be set")
    if self._DB_NAME_PATTERN.fullmatch(db_name) is None:
        raise TDEngineValueError(f"TDEngine database '{db_name}' does not comply with the naming convention")

    for candidate in (self._table, self._supertable):
        # Unset (falsy) names are simply not checked here.
        if not candidate:
            continue
        if self._TABLE_NAME_PATTERN.fullmatch(candidate) is None:
            raise TDEngineValueError(f"TDEngine table name '{candidate}' does not comply with the naming convention")

@staticmethod
def _get_tdengine_type_to_column_func() -> dict[str, Callable[[list], "taosws.PyColumnView"]]:
import taosws
Expand Down
34 changes: 34 additions & 0 deletions tests/test_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import pytest

from storey.dtypes import TDEngineValueError
from storey.targets import TDEngineTarget


Expand All @@ -29,3 +34,32 @@ def test_columns_mapping_consistency() -> None:
else:
assert func.__name__.startswith(type_.lower())
assert func.__name__.endswith("_to_column")

@staticmethod
@pytest.mark.parametrize(
    "database, table, supertable, table_col, tag_cols",
    [
        # Database is not set at all.
        (None, None, "my_super_tb", "pass_this_check", ["also_this_one"]),
        # Supertable name contains spaces.
        ("mydb", None, "my super tb", "pass_this_check", ["also_this_one"]),
        # Table name starts with a digit.
        ("_db", "9table", None, None, None),
        # Table name starts with a space.
        ("_db", " cars", None, None, None),
    ],
)
def test_invalid_names(
    database: Optional[str],
    table: Optional[str],
    supertable: Optional[str],
    table_col: Optional[str],
    tag_cols: Optional[list[str]],
) -> None:
    """Constructing a ``TDEngineTarget`` with a bad DB/table name must raise ``TDEngineValueError``."""
    target_kwargs = {
        "url": "taosws://root:taosdata@localhost:6041",
        "time_col": "ts",
        "columns": ["value"],
        "table_col": table_col,
        "tag_cols": tag_cols,
        "database": database,
        "table": table,
        "supertable": supertable,
    }
    with pytest.raises(TDEngineValueError):
        TDEngineTarget(**target_kwargs)

0 comments on commit dddcfaf

Please sign in to comment.