From b6647c0931a173f90d817263da5db733f3bf93d5 Mon Sep 17 00:00:00 2001 From: Bryan Forbes Date: Mon, 17 Dec 2018 16:11:18 -0600 Subject: [PATCH] Add more selectable typings (#33) --- sqlalchemy-stubs/sql/elements.pyi | 11 +- sqlalchemy-stubs/sql/schema.pyi | 9 +- sqlalchemy-stubs/sql/selectable.pyi | 368 ++++++++++-------- test/test-data/sqlalchemy-sql-selectable.test | 62 +++ test/testsql.py | 3 +- 5 files changed, 281 insertions(+), 172 deletions(-) create mode 100644 test/test-data/sqlalchemy-sql-selectable.test diff --git a/sqlalchemy-stubs/sql/elements.pyi b/sqlalchemy-stubs/sql/elements.pyi index 0fbc442..5fd8de6 100644 --- a/sqlalchemy-stubs/sql/elements.pyi +++ b/sqlalchemy-stubs/sql/elements.pyi @@ -2,6 +2,7 @@ from typing import ( Any, Optional, Union, Type, TypeVar, Generic, Callable, List, Dict, Set, Iterator, Iterable, Tuple as _TupleType, Mapping, overload, Text ) +from typing_extensions import Protocol from . import operators from .. import util from .visitors import Visitable as Visitable @@ -14,6 +15,7 @@ from .selectable import TextAsFrom, TableClause from .functions import FunctionElement _T = TypeVar('_T') +_T_contra = TypeVar('_T_contra', contravariant=True) _V = TypeVar('_V') _U = TypeVar('_U') @@ -142,13 +144,20 @@ class True_(ColumnElement[bool]): _CL = TypeVar('_CL', bound=ClauseList) +class _LiteralAsTextCallback(Protocol[_T_contra]): + def __call__(self, clause: _T_contra) -> List[ClauseElement]: ... + class ClauseList(ClauseElement): __visit_name__: str = ... operator: Any = ... group: bool = ... group_contents: bool = ... clauses: List[ClauseElement] = ... - def __init__(self, *clauses: ClauseElement, operator: Callable[..., Any] = ..., group: bool = ..., + @overload + def __init__(self, *clauses: _T, operator: Callable[..., Any] = ..., group: bool = ..., group_contents: bool = ..., + _literal_as_text: _LiteralAsTextCallback[_T] = ..., **kwargs: Any) -> None: ... + @overload + def __init__(self, *clauses: Optional[Union[str, bool, Visitable]], operator: Callable[..., Any] = ..., group: bool = ..., group_contents: bool = ..., **kwargs: Any) -> None: ... def __iter__(self) -> Iterator[ClauseElement]: ... def __len__(self) -> int: ... diff --git a/sqlalchemy-stubs/sql/schema.pyi b/sqlalchemy-stubs/sql/schema.pyi index 9c65eb0..2d330c0 100644 --- a/sqlalchemy-stubs/sql/schema.pyi +++ b/sqlalchemy-stubs/sql/schema.pyi @@ -1,4 +1,4 @@ -from typing import Any, Optional, Set, Generic, TypeVar, overload, Type +from typing import Any, Optional, Set, Generic, TypeVar, Type, Iterable, overload from . import visitors from .base import SchemaEventTarget as SchemaEventTarget, DialectKWArgs as DialectKWArgs, ColumnCollection from .elements import ColumnClause as ColumnClause @@ -12,13 +12,14 @@ BLANK_SCHEMA: Any = ... class SchemaItem(SchemaEventTarget, visitors.Visitable): __visit_name__: str = ... - def get_children(self, **kwargs): ... + def get_children(self, **kwargs: Any) -> Iterable[Any]: ... @property def quote(self): ... @property def info(self): ... -class Table(DialectKWArgs, SchemaItem, TableClause): +# Definition of "get_children" in base class "SchemaItem" is incompatible with definition in base class "TableClause" +class Table(DialectKWArgs, SchemaItem, TableClause): # type: ignore __visit_name__: str metadata: Any schema: Any @@ -43,7 +44,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause): def append_column(self, column): ... def append_constraint(self, constraint): ... def append_ddl_listener(self, event_name, listener): ... - def get_children(self, column_collections: bool = ..., schema_visitor: bool = ..., **kw): ... + def get_children(self, column_collections: bool = ..., schema_visitor: bool = ..., **kwargs): ... def exists(self, bind: Optional[Any] = ...): ... def create(self, bind: Optional[Any] = ..., checkfirst: bool = ...): ... def drop(self, bind: Optional[Any] = ..., checkfirst: bool = ...): ... diff --git a/sqlalchemy-stubs/sql/selectable.pyi b/sqlalchemy-stubs/sql/selectable.pyi index b7e9f14..b1649ae 100644 --- a/sqlalchemy-stubs/sql/selectable.pyi +++ b/sqlalchemy-stubs/sql/selectable.pyi @@ -1,45 +1,60 @@ -from typing import Any, Optional, TypeVar, Set -from .elements import ClauseElement as ClauseElement, Grouping as Grouping, UnaryExpression as UnaryExpression +from typing import Any, Optional, Union, TypeVar, List, Iterable, Sequence, Mapping, Set, Tuple, Type +from .elements import ( + ClauseElement as ClauseElement, Grouping as Grouping, UnaryExpression as UnaryExpression, ColumnElement, ColumnClause, + TextClause, Label, BindParameter +) from .base import Immutable as Immutable, Executable as Executable, Generative as Generative, ImmutableColumnCollection, ColumnSet from .annotation import Annotated as Annotated -from .schema import ForeignKey +from ..engine import Engine, Connection +from .schema import ForeignKey, Table +from .functions import Function +from .dml import Insert, Update, Delete +from .type_api import TypeEngine +from .visitors import Visitable +from .. import util -def subquery(alias, *args, **kwargs): ... -def alias(selectable, name: Optional[Any] = ..., flat: bool = ...): ... -def lateral(selectable, name: Optional[Any] = ...): ... -def tablesample(selectable, sampling, name: Optional[Any] = ..., seed: Optional[Any] = ...): ... +_T = TypeVar('_T') + +def subquery(alias: str, *args: Any, **kwargs: Any) -> Alias: ... +def alias(selectable: FromClause, name: Optional[Any] = ..., flat: bool = ...) -> Alias: ... +def lateral(selectable: FromClause, name: Optional[Any] = ...) -> Lateral: ... +def tablesample(selectable: FromClause, sampling: float, name: Optional[str] = ..., seed: Optional[Any] = ...) -> TableSample: ... + +_S = TypeVar('_S', bound=Selectable) class Selectable(ClauseElement): __visit_name__: str = ... is_selectable: bool = ... @property - def selectable(self): ... + def selectable(self: _S) -> _S: ... _HP = TypeVar('_HP', bound=HasPrefixes) class HasPrefixes(object): - def prefix_with(self: _HP, *expr, **kw) -> _HP: ... + def prefix_with(self: _HP, *expr: Union[str, ClauseElement], **kw: Any) -> _HP: ... _HS = TypeVar('_HS', bound=HasSuffixes) class HasSuffixes(object): - def suffix_with(self: _HS, *expr, **kw) -> _HS: ... + def suffix_with(self: _HS, *expr: Union[str, ClauseElement], **kw: Any) -> _HS: ... class FromClause(Selectable): __visit_name__: str = ... named_with_column: bool = ... - schema: Any = ... - def count(self, whereclause: Optional[Any] = ..., **params) -> Select: ... - def select(self, whereclause: Optional[Any] = ..., **params) -> Select: ... - def join(self, right, onclause: Optional[Any] = ..., isouter: bool = ..., full: bool = ...) -> Join: ... - def outerjoin(self, right, onclause: Optional[Any] = ..., full: bool = ...) -> Join: ... - def alias(self, name: Optional[Any] = ..., flat: bool = ...) -> Alias: ... - def lateral(self, name: Optional[Any] = ...) -> Lateral: ... - def tablesample(self, sampling, name: Optional[Any] = ..., seed: Optional[Any] = ...) -> TableSample: ... - def is_derived_from(self, fromclause) -> bool: ... - def replace_selectable(self, old, alias): ... - def correspond_on_equivalents(self, column, equivalents): ... - def corresponding_column(self, column, require_embedded: bool = ...): ... + schema: Optional[str] = ... + def count(self, whereclause: Optional[Union[str, bool, Visitable]] = ..., **params: Any) -> Select: ... + def select(self, whereclause: Optional[Union[str, bool, Visitable]] = ..., **params: Any) -> Select: ... + def join(self, right: FromClause, onclause: Optional[ClauseElement] = ..., isouter: bool = ..., full: bool = ...) -> Join: ... + def outerjoin(self, right: FromClause, onclause: Optional[FromClause] = ..., full: bool = ...) -> Join: ... + def alias(self, name: Optional[str] = ..., flat: bool = ...) -> Alias: ... + def lateral(self, name: Optional[str] = ...) -> Lateral: ... + def tablesample(self, sampling: Union[float, Function[float]], name: Optional[str] = ..., + seed: Optional[Any] = ...) -> TableSample: ... + def is_derived_from(self, fromclause: FromClause) -> bool: ... + def replace_selectable(self, old: FromClause, alias: Alias) -> FromClause: ... + def correspond_on_equivalents(self, column: ColumnElement[Any], + equivalents: Mapping[Any, Any]) -> Optional[ColumnElement[Any]]: ... + def corresponding_column(self, column: ColumnElement[Any], require_embedded: bool = ...) -> ColumnElement[Any]: ... @property def description(self) -> str: ... @property @@ -50,45 +65,52 @@ class FromClause(Selectable): def foreign_keys(self) -> Set[ForeignKey]: ... c: ImmutableColumnCollection = ... +_J = TypeVar('_J', bound=Join) + class Join(FromClause): __visit_name__: str = ... - left: Any = ... - right: Any = ... - onclause: Any = ... - isouter: Any = ... - full: Any = ... - def __init__(self, left, right, onclause: Optional[Any] = ..., isouter: bool = ..., full: bool = ...) -> None: ... + left: FromClause = ... + right: FromClause = ... + onclause: ClauseElement = ... + isouter: bool = ... + full: bool = ... + def __init__(self, left: FromClause, right: FromClause, onclause: Optional[ClauseElement] = ..., + isouter: bool = ..., full: bool = ...) -> None: ... @property - def description(self): ... - def is_derived_from(self, fromclause): ... - def self_group(self, against: Optional[Any] = ...): ... - def get_children(self, **kwargs): ... - def select(self, whereclause: Optional[Any] = ..., **kwargs): ... + def description(self) -> str: ... + def is_derived_from(self, fromclause: FromClause) -> bool: ... + def self_group(self, against: Optional[Any] = ...) -> FromGrouping: ... + def get_children(self, **kwargs: Any) -> Tuple[FromClause, FromClause, ClauseElement]: ... + def select(self, whereclause: Optional[Union[str, bool, Visitable]] = ..., **kwargs: Any) -> Select: ... @property - def bind(self): ... - def alias(self, *args, **kwargs): ... + def bind(self) -> Optional[Union[Engine, Connection]]: ... + # Return type of "alias" incompatible with supertype "FromClause" + def alias(self, name: Optional[str] = ..., flat: bool = ...) -> Union[Alias, Join]: ... # type: ignore @classmethod - def _create_outerjoin(cls, left, right, onclause: Optional[Any] = ..., full: bool = ...) -> Join: ... + def _create_outerjoin(cls: Type[_J], left: FromClause, right: FromClause, onclause: Optional[ClauseElement] = ..., + full: bool = ...) -> _J: ... @classmethod - def _create_join(cls, left, right, onclause: Optional[Any] = ..., isouter: bool = ..., - full: bool = ...) -> Join: ... + def _create_join(cls: Type[_J], left: FromClause, right: FromClause, onclause: Optional[ClauseElement] = ..., + isouter: bool = ..., full: bool = ...) -> _J: ... + +_A = TypeVar('_A', bound=Alias) class Alias(FromClause): __visit_name__: str = ... named_with_column: bool = ... - original: Any = ... - supports_execution: Any = ... - element: Any = ... - name: Any = ... - def __init__(self, selectable, name: Optional[Any] = ...) -> None: ... - def self_group(self, target: Optional[Any] = ...): ... + original: Selectable = ... + supports_execution: bool = ... + element: Selectable = ... + name: Optional[str] = ... + def __init__(self, selectable: Selectable, name: Optional[str] = ...) -> None: ... + def self_group(self: _A, against: Optional[Any] = ...) -> Union[FromGrouping, _A]: ... @property - def description(self): ... - def as_scalar(self): ... - def is_derived_from(self, fromclause): ... - def get_children(self, column_collections: bool = ..., **kw): ... + def description(self) -> str: ... + def as_scalar(self) -> Any: ... + def is_derived_from(self, fromclause: FromClause) -> bool: ... + def get_children(self, column_collections: bool = ..., **kw: Any) -> Iterable[Union[ColumnElement[Any], Selectable]]: ... @property - def bind(self): ... + def bind(self) -> Optional[Union[Engine, Connection]]: ... class Lateral(Alias): __visit_name__: str = ... @@ -97,195 +119,209 @@ class TableSample(Alias): __visit_name__: str = ... sampling: Any = ... seed: Any = ... - def __init__(self, selectable, sampling, name: Optional[Any] = ..., seed: Optional[Any] = ...) -> None: ... + def __init__(self, selectable: FromClause, sampling: Union[float, Function[float]], name: Optional[str] = ..., + seed: Optional[Any] = ...) -> None: ... class CTE(Generative, HasSuffixes, Alias): __visit_name__: str = ... - recursive: Any = ... - def __init__(self, selectable, name: Optional[Any] = ..., recursive: bool = ..., + recursive: bool = ... + def __init__(self, selectable: Select, name: Optional[str] = ..., recursive: bool = ..., _cte_alias: Optional[Any] = ..., _restates: Any = ..., _suffixes: Optional[Any] = ...) -> None: ... - def alias(self, name: Optional[Any] = ..., flat: bool = ...) -> CTE: ... - def union(self, other) -> CTE: ... - def union_all(self, other) -> CTE: ... + def alias(self, name: Optional[str] = ..., flat: bool = ...) -> CTE: ... + def union(self, other: Select) -> CTE: ... + def union_all(self, other: Select) -> CTE: ... class HasCTE(object): - def cte(self, name: Optional[Any] = ..., recursive: bool = ...) -> CTE: ... + def cte(self, name: Optional[str] = ..., recursive: bool = ...) -> CTE: ... class FromGrouping(FromClause): __visit_name__: str = ... - element: Any = ... - def __init__(self, element) -> None: ... + element: FromClause = ... + def __init__(self, element: FromClause) -> None: ... @property def columns(self) -> ImmutableColumnCollection: ... @property def primary_key(self) -> ColumnSet: ... @property def foreign_keys(self) -> Set[ForeignKey]: ... - def is_derived_from(self, element): ... - def alias(self, **kw): ... - def get_children(self, **kwargs): ... - def __getattr__(self, attr): ... + def is_derived_from(self, element: FromClause) -> bool: ... + # Return type of "alias" incompatible with supertype "FromClause" + def alias(self, name: Optional[str] = ..., flat: bool = ...) -> FromGrouping: ... # type: ignore + def get_children(self, **kwargs: Any) -> Tuple[FromClause]: ... + def __getattr__(self, attr: str) -> Any: ... class TableClause(Immutable, FromClause): __visit_name__: str = ... named_with_column: bool = ... implicit_returning: bool = ... - name: Any = ... + name: str = ... primary_key: ColumnSet = ... foreign_keys: Set[ForeignKey] = ... - def __init__(self, name, *columns) -> None: ... + def __init__(self, name: str, *columns: ColumnClause[Any]) -> None: ... @property - def description(self): ... - def append_column(self, c): ... - def get_children(self, **kwargs): ... - def insert(self, values: Optional[Any] = ..., inline: bool = ..., **kwargs): ... - def update(self, whereclause: Optional[Any] = ..., values: Optional[Any] = ..., inline: bool = ..., - **kwargs): ... - def delete(self, whereclause: Optional[Any] = ..., **kwargs): ... + def description(self) -> str: ... + def append_column(self, c: ColumnClause[Any]): ... + def get_children(self, column_collections: bool = ..., **kwargs: Any) -> List[ColumnClause[Any]]: ... + def insert(self, values: Optional[Union[Mapping[str, Any], + Mapping[ColumnClause[Any], Any], + Mapping[Union[str, ColumnClause[Any]], Any]]] = ..., + inline: bool = ..., **kwargs: Any) -> Insert: ... + def update(self, whereclause: Optional[Union[str, bool, Visitable]] = ..., + values: Optional[Mapping[Union[str, ColumnClause], Any]] = ..., + inline: bool = ..., **kwargs: Any) -> Update: ... + def delete(self, whereclause: Optional[Union[str, bool, Visitable]] = ..., **kwargs: Any) -> Delete: ... class ForUpdateArg(ClauseElement): @classmethod - def parse_legacy_select(self, arg): ... + def parse_legacy_select(self, arg: Optional[str]) -> Optional[ForUpdateArg]: ... @property - def legacy_for_update_value(self): ... - nowait: Any = ... - read: Any = ... - skip_locked: Any = ... - key_share: Any = ... + def legacy_for_update_value(self) -> Union[str, bool]: ... + nowait: bool = ... + read: bool = ... + skip_locked: bool = ... + key_share: bool = ... of: Any = ... - def __init__(self, nowait: bool = ..., read: bool = ..., of: Optional[Any] = ..., + def __init__(self, nowait: bool = ..., read: bool = ..., of: Optional[Union[TextClause, Sequence[ColumnClause[Any]]]] = ..., skip_locked: bool = ..., key_share: bool = ...) -> None: ... _SB = TypeVar('_SB', bound=SelectBase) class SelectBase(HasCTE, Executable, FromClause): - def as_scalar(self): ... - def label(self, name): ... + def as_scalar(self) -> ScalarSelect[Any]: ... + def label(self, name: str) -> Label: ... def autocommit(self: _SB) -> _SB: ... _GS = TypeVar('_GS', bound=GenerativeSelect) class GenerativeSelect(SelectBase): - use_labels: Any = ... - def __init__(self, use_labels: bool = ..., for_update: bool = ..., limit: Optional[Any] = ..., - offset: Optional[Any] = ..., order_by: Optional[Any] = ..., - group_by: Optional[Any] = ..., bind: Optional[Any] = ..., autocommit: Optional[Any] = ...) -> None: ... - @property - def for_update(self): ... - @for_update.setter - def for_update(self, value): ... - def with_for_update(self: _GS, nowait: bool = ..., read: bool = ..., of: Optional[Any] = ..., + use_labels: bool = ... + for_update: Union[str, bool] = ... + def __init__(self, use_labels: bool = ..., for_update: bool = ..., limit: Optional[int] = ..., + offset: Optional[int] = ..., + order_by: Optional[Union[int, str, Visitable, Iterable[Union[int, str, Visitable]]]] = ..., + group_by: Optional[Union[int, str, Visitable, Iterable[Union[int, str, Visitable]]]] = ..., + bind: Optional[Union[Engine, Connection]] = ..., + autocommit: Optional[bool] = ...) -> None: ... + def with_for_update(self: _GS, nowait: bool = ..., read: bool = ..., + of: Optional[Union[TextClause, Sequence[ColumnClause[Any]]]] = ..., skip_locked: bool = ..., key_share: bool = ...) -> _GS: ... def apply_labels(self: _GS) -> _GS: ... - def limit(self: _GS, limit) -> _GS: ... - def offset(self: _GS, offset) -> _GS: ... - def order_by(self: _GS, *clauses) -> _GS: ... - def group_by(self: _GS, *clauses) -> _GS: ... - def append_order_by(self, *clauses): ... - def append_group_by(self, *clauses): ... + def limit(self: _GS, limit: Optional[Union[int, str, Visitable]]) -> _GS: ... + def offset(self: _GS, offset: Optional[Union[int, str, Visitable]]) -> _GS: ... + def order_by(self: _GS, *clauses: Optional[Union[str, bool, Visitable]]) -> _GS: ... + def group_by(self: _GS, *clauses: Optional[Union[str, bool, Visitable]]) -> _GS: ... + def append_order_by(self, *clauses: Optional[Union[str, bool, Visitable]]): ... + def append_group_by(self, *clauses: Optional[Union[str, bool, Visitable]]): ... class CompoundSelect(GenerativeSelect): __visit_name__: str = ... - UNION: Any = ... - UNION_ALL: Any = ... - EXCEPT: Any = ... - EXCEPT_ALL: Any = ... - INTERSECT: Any = ... - INTERSECT_ALL: Any = ... - keyword: Any = ... - selects: Any = ... - def __init__(self, keyword, *selects, **kwargs) -> None: ... - def self_group(self, against: Optional[Any] = ...): ... - def is_derived_from(self, fromclause): ... - def get_children(self, column_collections: bool = ..., **kwargs): ... - def bind(self): ... + UNION: util.symbol = ... + UNION_ALL: util.symbol = ... + EXCEPT: util.symbol = ... + EXCEPT_ALL: util.symbol = ... + INTERSECT: util.symbol = ... + INTERSECT_ALL: util.symbol = ... + keyword: util.symbol = ... + selects: List[Selectable] = ... + def __init__(self, keyword: util.symbol, *selects: Selectable, **kwargs: Any) -> None: ... + def self_group(self, against: Optional[Any] = ...) -> FromGrouping: ... + def is_derived_from(self, fromclause: FromClause): ... + def get_children(self, column_collections: bool = ..., + **kwargs: Any) -> List[Union[ColumnClause[Any], ClauseElement, Selectable]]: ... + def bind(self) -> Optional[Union[Engine, Connection]]: ... @classmethod - def _create_union(cls, *selects, **kwargs) -> CompoundSelect: ... + def _create_union(cls, *selects: Selectable, **kwargs: Any) -> CompoundSelect: ... @classmethod - def _create_union_all(cls, *selects, **kwargs) -> CompoundSelect: ... + def _create_union_all(cls, *selects: Selectable, **kwargs: Any) -> CompoundSelect: ... @classmethod - def _create_except(cls, *selects, **kwargs) -> CompoundSelect: ... + def _create_except(cls, *selects: Selectable, **kwargs: Any) -> CompoundSelect: ... @classmethod - def _create_except_all(cls, *selects, **kwargs) -> CompoundSelect: ... + def _create_except_all(cls, *selects: Selectable, **kwargs: Any) -> CompoundSelect: ... @classmethod - def _create_intersect(cls, *selects, **kwargs) -> CompoundSelect: ... + def _create_intersect(cls, *selects: Selectable, **kwargs: Any) -> CompoundSelect: ... @classmethod - def _create_intersect_all(cls, *selects, **kwargs) -> CompoundSelect: ... + def _create_intersect_all(cls, *selects: Selectable, **kwargs: Any) -> CompoundSelect: ... -_S = TypeVar('_S', bound=Select) +_SE = TypeVar('_SE', bound=Select) class Select(HasPrefixes, HasSuffixes, GenerativeSelect): __visit_name__: str = ... - def __init__(self, columns: Optional[Any] = ..., whereclause: Optional[Any] = ..., - from_obj: Optional[Any] = ..., distinct: bool = ..., having: Optional[Any] = ..., - correlate: bool = ..., prefixes: Optional[Any] = ..., suffixes: Optional[Any] = ..., - **kwargs) -> None: ... + def __init__(self, columns: Optional[Iterable[Union[ColumnElement[Any], FromClause]]] = ..., + whereclause: Optional[Union[str, bool, Visitable]] = ..., + from_obj: Optional[Union[str, Selectable, Iterable[Union[str, Selectable]]]] = ..., + group_by: Optional[Union[int, str, Visitable, Iterable[Union[int, str, Visitable]]]] = ..., + having: Optional[Union[str, bool, Visitable]] = ..., + order_by: Optional[Union[int, str, Visitable, Iterable[Union[int, str, Visitable]]]] = ..., + distinct: bool = ..., correlate: bool = ..., limit: Optional[int] = ..., offset: Optional[int] = ..., + use_labels: bool = ..., autocommit: bool = ..., bind: Union[Engine, Connection] = ..., + prefixes: Optional[Any] = ..., suffixes: Optional[Any] = ..., + **kwargs: Any) -> None: ... @property - def froms(self): ... - def with_statement_hint(self, text, dialect_name: str = ...): ... - def with_hint(self: _S, selectable, text, dialect_name: str = ...) -> _S: ... + def froms(self) -> List[FromClause]: ... + def with_statement_hint(self: _SE, text: str, dialect_name: str = ...) -> _SE: ... + def with_hint(self: _SE, selectable: Union[Table, Alias], text: str, dialect_name: str = ...) -> _SE: ... @property - def type(self): ... + def type(self) -> Any: ... @property - def locate_all_froms(self): ... + def locate_all_froms(self) -> List[FromClause]: ... @property - def inner_columns(self): ... - def is_derived_from(self, fromclause): ... - def get_children(self, column_collections: bool = ..., **kwargs): ... - def column(self: _S, column) -> _S: ... - def reduce_columns(self, only_synonyms: bool = ...): ... - def with_only_columns(self: _S, columns): ... - def where(self: _S, whereclause) -> _S: ... - def having(self: _S, having) -> _S: ... - def distinct(self: _S, *expr) -> _S: ... - def select_from(self: _S, fromclause) -> _S: ... - def correlate(self: _S, *fromclauses) -> _S: ... - def correlate_except(self: _S, *fromclauses) -> _S: ... - def append_correlation(self, fromclause): ... - def append_column(self, column): ... - def append_prefix(self, clause): ... - def append_whereclause(self, whereclause): ... - def append_having(self, having): ... - def append_from(self, fromclause): ... - def self_group(self, against: Optional[Any] = ...): ... - def union(self, other, **kwargs): ... - def union_all(self, other, **kwargs): ... - def except_(self, other, **kwargs): ... - def except_all(self, other, **kwargs): ... - def intersect(self, other, **kwargs): ... - def intersect_all(self, other, **kwargs): ... - def bind(self): ... + def inner_columns(self) -> Iterable[ColumnElement[Any]]: ... + def is_derived_from(self, fromclause: FromClause) -> bool: ... + def get_children(self, column_collections: bool = ..., **kwargs: Any) -> List[ClauseElement]: ... + def column(self: _SE, column: ColumnElement[Any]) -> _SE: ... + def reduce_columns(self: _SE, only_synonyms: bool = ...) -> _SE: ... + def with_only_columns(self: _SE, columns: Iterable[ColumnElement[Any]]) -> _SE: ... + def where(self: _SE, whereclause: Union[str, bool, Visitable]) -> _SE: ... + def having(self: _SE, having: Union[str, bool, Visitable]) -> _SE: ... + def distinct(self: _SE, *expr: ColumnElement[Any]) -> _SE: ... + def select_from(self: _SE, fromclause: FromClause) -> _SE: ... + def correlate(self: _SE, *fromclauses: FromClause) -> _SE: ... + def correlate_except(self: _SE, *fromclauses: FromClause) -> _SE: ... + def append_correlation(self, fromclause: FromClause) -> None: ... + def append_column(self, column: ColumnElement[Any]) -> None: ... + def append_prefix(self, clause) -> None: ... + def append_whereclause(self, whereclause: Union[str, bool, Visitable]) -> None: ... + def append_having(self, having: Union[str, bool, Visitable]) -> None: ... + def append_from(self, fromclause: FromClause) -> None: ... + def self_group(self: _SE, against: Optional[Any] = ...) -> Union[_SE, FromGrouping]: ... + def union(self, other: Selectable, **kwargs: Any) -> CompoundSelect: ... + def union_all(self, other: Selectable, **kwargs: Any) -> CompoundSelect: ... + def except_(self, other: Selectable, **kwargs: Any) -> CompoundSelect: ... + def except_all(self, other: Selectable, **kwargs: Any) -> CompoundSelect: ... + def intersect(self, other: Selectable, **kwargs: Any) -> CompoundSelect: ... + def intersect_all(self, other: Selectable, **kwargs: Any) -> CompoundSelect: ... + def bind(self) -> Optional[Union[Engine, Connection]]: ... _SS = TypeVar('_SS', bound=ScalarSelect) -class ScalarSelect(Generative, Grouping): - element: Any = ... - type: Any = ... - def __init__(self, element) -> None: ... +class ScalarSelect(Generative, Grouping[_T]): + element: ClauseElement = ... + type: TypeEngine[_T] = ... + def __init__(self, element: ClauseElement) -> None: ... @property def columns(self): ... c: Any = ... - def where(self: _SS, crit) -> _SS: ... - def self_group(self, **kwargs): ... + def where(self: _SS, crit: ClauseElement) -> _SS: ... + def self_group(self: _SS, **kwargs: Any) -> _SS: ... # type: ignore # return type incompatible with all supertypes class Exists(UnaryExpression): __visit_name__: Any = ... - def __init__(self, *args, **kwargs) -> None: ... - def select(self, whereclause: Optional[Any] = ..., **params): ... - def correlate(self, *fromclause): ... - def correlate_except(self, *fromclause): ... - def select_from(self, clause): ... - def where(self, clause): ... + def __init__(self, *args: Union[Select, str], **kwargs: Any) -> None: ... + def select(self, whereclause: Optional[Union[str, bool, Visitable]] = ..., **params: Any) -> Select: ... + def correlate(self, *fromclause: FromClause) -> Exists: ... + def correlate_except(self, *fromclause: FromClause) -> Exists: ... + def select_from(self, clause: FromClause) -> Exists: ... + def where(self, clause: ClauseElement) -> Exists: ... _TAF = TypeVar('_TAF', bound=TextAsFrom) class TextAsFrom(SelectBase): __visit_name__: str = ... - element: Any = ... + element: TextClause = ... column_args: Any = ... positional: Any = ... - def __init__(self, text, columns, positional: bool = ...) -> None: ... - def bindparams(self: _TAF, *binds, **bind_as_values) -> _TAF: ... + def __init__(self, text: TextClause, columns: ColumnClause[Any], positional: bool = ...) -> None: ... + def bindparams(self: _TAF, *binds: BindParameter[Any], **bind_as_values: Any) -> _TAF: ... class AnnotatedFromClause(Annotated): - def __init__(self, element, values) -> None: ... + def __init__(self, element: FromClause, values: Any) -> None: ... diff --git a/test/test-data/sqlalchemy-sql-selectable.test b/test/test-data/sqlalchemy-sql-selectable.test new file mode 100644 index 0000000..b7dc937 --- /dev/null +++ b/test/test-data/sqlalchemy-sql-selectable.test @@ -0,0 +1,62 @@ +[case testSelectableSelect] +from sqlalchemy import table, select, column, func + +t = table('user', column('id'), column('name'), column('description'), column('articles')) +s = select(columns=[column('id'), column('name')], + whereclause=column('name') == 'foo', + from_obj=t, + group_by=column('name'), + having=func.count(column('articles')) > 2, + order_by=column('id'), + distinct=True, + correlate=False, + limit=10, + offset=30) +[out] + +[case testSelectableSelectIterables] +from sqlalchemy import table, select, column, alias, func + +t = table('user', column('id'), column('name'), column('description'), column('articles')) +st = table('stuff', column('id')) +a = alias(t, 'my_user') +s = select(columns={column('id'), column('name'), a}, + whereclause='name == "foo"', + from_obj=[t, 'stuff'], + group_by=[column('name'), 'description'], + having='count(articles) > 2', + order_by=[column('id'), 'name', 1]) +[out] + +[case testSelectableSelectLimit] +from sqlalchemy import select + +select().limit(None) +select().limit(10) +[out] + +[case testSelectableSelectOffset] +from sqlalchemy import select + +select().offset(None) +select().offset(10) +[out] + +[case testSelectableSelectOrderByGroupBy] +from sqlalchemy import select, column + +select() \ + .order_by('name', column('description')) \ + .group_by('name', column('description')) +[out] + +[case testSelectableTableClauseInsert] +from typing import Any, Dict, Union +from sqlalchemy import table, column, String + +t = table('user', column('id'), column('name'), column('description'), column('articles')) +t.insert({'name': 'foo'}) +t.insert({column('name', String): 'foo'}) +# t.insert({column('name', String): 'foo', +# 'description': 'bar'}) +[out] diff --git a/test/testsql.py b/test/testsql.py index a162fa9..33e8424 100644 --- a/test/testsql.py +++ b/test/testsql.py @@ -22,7 +22,8 @@ class SQLDataSuite(DataSuite): files = ['sqlalchemy-basics.test', 'sqlalchemy-sql-elements.test', - 'sqlalchemy-sql-sqltypes.test'] + 'sqlalchemy-sql-sqltypes.test', + 'sqlalchemy-sql-selectable.test'] data_prefix = test_data_prefix def run_case(self, testcase: DataDrivenTestCase) -> None: