Skip to content

Commit

Permalink
YDB FQ: fix tests after type renaming (ydb-platform#12589)
Browse files Browse the repository at this point in the history
  • Loading branch information
vitalyisaev2 committed Dec 16, 2024
1 parent 0c4f6d4 commit 9b690f9
Show file tree
Hide file tree
Showing 55 changed files with 372 additions and 354 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Dict
import functools

from ydb.library.yql.providers.generic.connector.api.common.data_source_pb2 import EDataSourceKind, EProtocol
from yql.essentials.providers.common.proto.gateways_config_pb2 import EGenericDataSourceKind, EGenericProtocol
from ydb.library.yql.providers.generic.connector.api.service.protos.connector_pb2 import EDateTimeFormat
from ydb.library.yql.providers.generic.connector.tests.utils.database import Database
from ydb.library.yql.providers.generic.connector.tests.utils.settings import GenericSettings
Expand All @@ -13,26 +13,26 @@
@dataclass
class BaseTestCase:
name_: str
data_source_kind: EDataSourceKind.ValueType
data_source_kind: EGenericDataSourceKind.ValueType
pragmas: Dict[str, str]
protocol: EProtocol
protocol: EGenericProtocol

@property
def name(self) -> str:
match self.data_source_kind:
case EDataSourceKind.CLICKHOUSE:
case EGenericDataSourceKind.CLICKHOUSE:
# ClickHouse has two kinds of network protocols: NATIVE and HTTP,
# so we append protocol name to the test case name
return f'{self.name_}_{EProtocol.Name(self.protocol)}'
case EDataSourceKind.MS_SQL_SERVER:
return f'{self.name_}_{EGenericProtocol.Name(self.protocol)}'
case EGenericDataSourceKind.MS_SQL_SERVER:
return self.name_
case EDataSourceKind.MYSQL:
case EGenericDataSourceKind.MYSQL:
return self.name_
case EDataSourceKind.ORACLE:
case EGenericDataSourceKind.ORACLE:
return self.name_
case EDataSourceKind.POSTGRESQL:
case EGenericDataSourceKind.POSTGRESQL:
return self.name_
case EDataSourceKind.YDB:
case EGenericDataSourceKind.YDB:
return self.name_
case _:
raise Exception(f'invalid data source: {self.data_source_kind}')
Expand All @@ -45,17 +45,17 @@ def database(self) -> Database:
'''
# FIXME: do not hardcode databases here
match self.data_source_kind:
case EDataSourceKind.CLICKHOUSE:
case EGenericDataSourceKind.CLICKHOUSE:
return Database(self.name, self.data_source_kind)
case EDataSourceKind.MS_SQL_SERVER:
case EGenericDataSourceKind.MS_SQL_SERVER:
return Database("master", self.data_source_kind)
case EDataSourceKind.MYSQL:
case EGenericDataSourceKind.MYSQL:
return Database("db", self.data_source_kind)
case EDataSourceKind.ORACLE:
case EGenericDataSourceKind.ORACLE:
return Database(self.name, self.data_source_kind)
case EDataSourceKind.POSTGRESQL:
case EGenericDataSourceKind.POSTGRESQL:
return Database(self.name, self.data_source_kind)
case EDataSourceKind.YDB:
case EGenericDataSourceKind.YDB:
return Database("local", self.data_source_kind)

@functools.cached_property
Expand All @@ -65,17 +65,17 @@ def table_name(self) -> str:
so we provide a random table name instead where necessary.
'''
match self.data_source_kind:
case EDataSourceKind.CLICKHOUSE:
case EGenericDataSourceKind.CLICKHOUSE:
return self.name_ # without protocol
case EDataSourceKind.MS_SQL_SERVER:
case EGenericDataSourceKind.MS_SQL_SERVER:
return self.name
case EDataSourceKind.MYSQL:
case EGenericDataSourceKind.MYSQL:
return self.name
case EDataSourceKind.ORACLE:
case EGenericDataSourceKind.ORACLE:
return self.name
case EDataSourceKind.POSTGRESQL:
case EGenericDataSourceKind.POSTGRESQL:
return 't' + make_random_string(8)
case EDataSourceKind.YDB:
case EGenericDataSourceKind.YDB:
return self.name
case _:
raise Exception(f'invalid data source: {self.data_source_kind}')
Expand All @@ -90,34 +90,34 @@ def pragmas_sql_string(self) -> str:
@property
def generic_settings(self) -> GenericSettings:
match self.data_source_kind:
case EDataSourceKind.CLICKHOUSE:
case EGenericDataSourceKind.CLICKHOUSE:
return GenericSettings(
date_time_format=EDateTimeFormat.YQL_FORMAT,
clickhouse_clusters=[
GenericSettings.ClickHouseCluster(database=self.database.name, protocol=EProtocol.NATIVE)
GenericSettings.ClickHouseCluster(database=self.database.name, protocol=EGenericProtocol.NATIVE)
],
)
case EDataSourceKind.MS_SQL_SERVER:
case EGenericDataSourceKind.MS_SQL_SERVER:
return GenericSettings(
date_time_format=EDateTimeFormat.YQL_FORMAT,
ms_sql_server_clusters=[GenericSettings.MsSQLServerCluster(database=self.database.name)],
)
case EDataSourceKind.MYSQL:
case EGenericDataSourceKind.MYSQL:
return GenericSettings(
date_time_format=EDateTimeFormat.YQL_FORMAT,
mysql_clusters=[GenericSettings.MySQLCluster(database=self.database.name)],
)
case EDataSourceKind.ORACLE:
case EGenericDataSourceKind.ORACLE:
return GenericSettings(
date_time_format=EDateTimeFormat.YQL_FORMAT,
oracle_clusters=[GenericSettings.OracleCluster(database=self.database.name, service_name=None)],
)
case EDataSourceKind.POSTGRESQL:
case EGenericDataSourceKind.POSTGRESQL:
return GenericSettings(
date_time_format=EDateTimeFormat.YQL_FORMAT,
postgresql_clusters=[GenericSettings.PostgreSQLCluster(database=self.database.name, schema=None)],
)
case EDataSourceKind.YDB:
case EGenericDataSourceKind.YDB:
return GenericSettings(
date_time_format=EDateTimeFormat.YQL_FORMAT,
ydb_clusters=[GenericSettings.YdbCluster(database=self.database.name)],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass

from ydb.library.yql.providers.generic.connector.tests.utils.settings import Settings
from ydb.library.yql.providers.generic.connector.api.common.data_source_pb2 import EDataSourceKind, EProtocol
from yql.essentials.providers.common.proto.gateways_config_pb2 import EGenericDataSourceKind, EGenericProtocol
from ydb.library.yql.providers.generic.connector.tests.common_test_cases.base import BaseTestCase
from ydb.library.yql.providers.generic.connector.tests.utils.settings import GenericSettings

Expand All @@ -16,14 +16,14 @@ def generic_settings(self) -> GenericSettings:
gs = super().generic_settings

# Overload setting for MySQL database
if self.data_source_kind == EDataSourceKind.MYSQL:
if self.data_source_kind == EGenericDataSourceKind.MYSQL:
for cluster in gs.mysql_clusters:
cluster.database = "missing_database"
for cluster in gs.oracle_clusters:
if self.service_name is not None:
cluster.service_name = self.service_name

if self.data_source_kind == EDataSourceKind.MS_SQL_SERVER:
if self.data_source_kind == EGenericDataSourceKind.MS_SQL_SERVER:
for cluster in gs.ms_sql_server_clusters:
cluster.database = "missing_database"

Expand All @@ -36,12 +36,12 @@ class Factory:
def __init__(self, ss: Settings):
self.ss = ss

def make_test_cases(self, data_source_kind: EDataSourceKind) -> List[TestCase]:
def make_test_cases(self, data_source_kind: EGenericDataSourceKind) -> List[TestCase]:
return [
TestCase(
name_="missing_database",
data_source_kind=data_source_kind,
protocol=EProtocol.NATIVE,
protocol=EGenericProtocol.NATIVE,
pragmas=dict(),
service_name=self.ss.oracle.service_name if self.ss.oracle is not None else None,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Sequence

from ydb.library.yql.providers.generic.connector.tests.utils.settings import Settings
from ydb.library.yql.providers.generic.connector.api.common.data_source_pb2 import EDataSourceKind, EProtocol
from yql.essentials.providers.common.proto.gateways_config_pb2 import EGenericDataSourceKind, EGenericProtocol
from ydb.library.yql.providers.generic.connector.tests.common_test_cases.base import BaseTestCase
from ydb.library.yql.providers.generic.connector.tests.utils.settings import GenericSettings

Expand Down Expand Up @@ -31,15 +31,15 @@ class Factory:
def __init__(self, ss: Settings):
self.ss = ss

def make_test_cases(self, data_source_kind: EDataSourceKind) -> List[TestCase]:
def make_test_cases(self, data_source_kind: EGenericDataSourceKind) -> List[TestCase]:
test_cases = []

test_case_name = 'missing_table'

test_case = TestCase(
name_=test_case_name,
data_source_kind=data_source_kind,
protocol=EProtocol.NATIVE,
protocol=EGenericProtocol.NATIVE,
pragmas=dict(),
service_name=self.ss.oracle.service_name if self.ss.oracle is not None else None,
)
Expand Down
Loading

0 comments on commit 9b690f9

Please sign in to comment.