Skip to content

Commit

Permalink
typing(discover): Adding some of the files for QueryBuilder to mypy (#…
Browse files Browse the repository at this point in the history
…29851)

- This adds a few files to mypy; it doesn't yet include search/events/filter
  or fields
  • Loading branch information
wmak authored Nov 11, 2021
1 parent 66e0e98 commit f5ecc4a
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 44 deletions.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ files = src/sentry/api/bases/external_actor.py,
src/sentry/snuba/outcomes.py,
src/sentry/snuba/query_subscription_consumer.py,
src/sentry/spans/**/*.py,
src/sentry/search/events/base.py,
src/sentry/search/events/builder.py,
src/sentry/search/events/types.py,
src/sentry/tasks/app_store_connect.py,
src/sentry/tasks/low_priority_symbolication.py,
src/sentry/tasks/store.py,
Expand Down
6 changes: 3 additions & 3 deletions src/sentry/search/events/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Mapping, Optional, Set
from typing import Dict, List, Mapping, Optional, Set, cast

from django.utils.functional import cached_property
from snuba_sdk.aliased_expression import AliasedExpression
Expand Down Expand Up @@ -37,9 +37,9 @@ def __init__(

self.resolve_column_name = resolve_column(self.dataset)

@cached_property
@cached_property # type: ignore
def project_slugs(self) -> Mapping[str, int]:
project_ids = self.params.get("project_id", [])
project_ids = cast(List[int], self.params.get("project_id", []))

if len(project_ids) > 0:
project_slugs = Project.objects.filter(id__in=project_ids)
Expand Down
41 changes: 22 additions & 19 deletions src/sentry/search/events/builder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, cast

from snuba_sdk.aliased_expression import AliasedExpression
from snuba_sdk.column import Column
Expand All @@ -16,7 +16,7 @@
from sentry.utils.snuba import Dataset


class QueryBuilder(QueryFilter):
class QueryBuilder(QueryFilter): # type: ignore
"""Builds a snql query"""

def __init__(
Expand All @@ -35,7 +35,7 @@ def __init__(
limit: Optional[int] = 50,
offset: Optional[int] = 0,
limitby: Optional[Tuple[str, int]] = None,
turbo: Optional[bool] = False,
turbo: bool = False,
sample_rate: Optional[float] = None,
):
super().__init__(dataset, params, auto_fields, functions_acl)
Expand Down Expand Up @@ -86,7 +86,7 @@ def groupby(self) -> Optional[List[SelectType]]:
else:
return []

def validate_aggregate_arguments(self):
def validate_aggregate_arguments(self) -> None:
for column in self.columns:
if column in self.aggregates:
continue
Expand Down Expand Up @@ -149,7 +149,7 @@ def get_snql_query(self) -> Query:
)


class TimeseriesQueryBuilder(QueryFilter):
class TimeseriesQueryBuilder(QueryFilter): # type: ignore
time_column = Column("time")

def __init__(
Expand Down Expand Up @@ -185,7 +185,8 @@ def __init__(
def select(self) -> List[SelectType]:
if not self.aggregates:
raise InvalidSearchQuery("Cannot query a timeseries without a Y-Axis")
return self.aggregates
# Casting for now since QueryFields/QueryFilter are only partially typed
return cast(List[SelectType], self.aggregates)

def get_snql_query(self) -> Query:
return Query(
Expand Down Expand Up @@ -247,14 +248,16 @@ def __init__(
params: ParamsType,
granularity: int,
top_events: List[Dict[str, Any]],
other: Optional[bool] = False,
other: bool = False,
query: Optional[str] = None,
selected_columns: Optional[List[str]] = None,
timeseries_columns: Optional[List[str]] = None,
equations: Optional[List[str]] = None,
limit: Optional[int] = 10000,
):
timeseries_equations, timeseries_functions = categorize_columns(timeseries_columns)
timeseries_equations, timeseries_functions = categorize_columns(
timeseries_columns if timeseries_columns is not None else []
)
super().__init__(
dataset,
params,
Expand All @@ -265,7 +268,7 @@ def __init__(
limit=limit,
)

self.fields = selected_columns
self.fields: List[str] = selected_columns if selected_columns is not None else []

if (conditions := self.resolve_top_event_conditions(top_events, other)) is not None:
self.where.append(conditions)
Expand All @@ -290,7 +293,7 @@ def translated_groupby(self) -> List[str]:
return sorted(translated)

def resolve_top_event_conditions(
self, top_events: Optional[Dict[str, Any]], other: bool
self, top_events: List[Dict[str, Any]], other: bool
) -> Optional[WhereType]:
"""Given a list of top events construct the conditions"""
conditions = []
Expand All @@ -314,7 +317,7 @@ def resolve_top_event_conditions(

resolved_field = self.resolve_column(field)

values = set()
values: Set[Any] = set()
for event in top_events:
if field in event:
alias = field
Expand All @@ -328,29 +331,29 @@ def resolve_top_event_conditions(
continue
else:
values.add(event.get(alias))
values = list(values)
if values:
values_list = list(values)
if values_list:
if field == "timestamp" or field.startswith("timestamp.to_"):
if not other:
# timestamp fields needs special handling, creating a big OR instead
function, operator = Or, Op.EQ
else:
# Needs to be a big AND when negated
function, operator = And, Op.NEQ
if len(values) > 1:
if len(values_list) > 1:
conditions.append(
function(
conditions=[
Condition(resolved_field, operator, value)
for value in sorted(values)
for value in sorted(values_list)
]
)
)
else:
conditions.append(Condition(resolved_field, operator, values[0]))
elif None in values:
conditions.append(Condition(resolved_field, operator, values_list[0]))
elif None in values_list:
# one of the values was null, but we can't do an in with null values, so split into two conditions
non_none_values = [value for value in values if value is not None]
non_none_values = [value for value in values_list if value is not None]
null_condition = Condition(
Function("isNull", [resolved_field]), Op.EQ if not other else Op.NEQ, 1
)
Expand All @@ -366,7 +369,7 @@ def resolve_top_event_conditions(
conditions.append(null_condition)
else:
conditions.append(
Condition(resolved_field, Op.IN if not other else Op.NOT_IN, values)
Condition(resolved_field, Op.IN if not other else Op.NOT_IN, values_list)
)
if len(conditions) > 1:
final_function = And if not other else Or
Expand Down
14 changes: 8 additions & 6 deletions src/sentry/search/events/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2279,12 +2279,12 @@ def normalize_percentile_alias(args: Mapping[str, str]) -> str:


class SnQLFunction(DiscoverFunction):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
self.snql_aggregate = kwargs.pop("snql_aggregate", None)
self.snql_column = kwargs.pop("snql_column", None)
super().__init__(*args, **kwargs)

def validate(self):
def validate(self) -> None:
# assert that all optional args have defaults available
for i, arg in enumerate(self.optional_args):
assert (
Expand Down Expand Up @@ -2989,7 +2989,7 @@ def is_function(self, function: str) -> bool:
return function in self.function_converter

def resolve_function(
self, function: str, match: Optional[Match[str]] = None, resolve_only=False
self, function: str, match: Optional[Match[str]] = None, resolve_only: bool = False
) -> SelectType:
"""Given a public function, resolve to the corresponding Snql function
Expand All @@ -3007,7 +3007,7 @@ def resolve_function(
if function in self.params.get("aliases", {}):
raise NotImplementedError("Aggregate aliases not implemented in snql field parsing yet")

name, combinator_name, arguments, alias = self.parse_function(match)
name, combinator_name, parsed_arguments, alias = self.parse_function(match)
snql_function = self.function_converter[name]

combinator = snql_function.find_combinator(combinator_name)
Expand All @@ -3022,7 +3022,9 @@ def resolve_function(

combinator_applied = False

arguments = snql_function.format_as_arguments(name, arguments, self.params, combinator)
arguments = snql_function.format_as_arguments(
name, parsed_arguments, self.params, combinator
)

self.function_alias_map[alias] = FunctionDetails(function, snql_function, arguments.copy())

Expand Down Expand Up @@ -3384,7 +3386,7 @@ def _resolve_percentile(
self,
args: Mapping[str, Union[str, Column, SelectType, int, float]],
alias: str,
fixed_percentile: float = None,
fixed_percentile: Optional[float] = None,
) -> SelectType:
return (
Function(
Expand Down
35 changes: 19 additions & 16 deletions src/sentry/search/events/filter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from datetime import datetime
from functools import reduce
from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -67,7 +69,7 @@ def is_condition(term):
return isinstance(term, (tuple, list)) and len(term) == 3 and term[1] in OPERATOR_TO_FUNCTION


def translate_transaction_status(val):
def translate_transaction_status(val: str) -> str:
if val not in SPAN_STATUS_NAME_TO_CODE:
raise InvalidSearchQuery(
f"Invalid value {val} for transaction.status condition. Accepted "
Expand Down Expand Up @@ -503,7 +505,7 @@ def _semver_build_filter_converter(
return ["release", "IN", versions]


def handle_operator_negation(operator):
def handle_operator_negation(operator: str) -> Tuple[str, bool]:
negated = False
if operator == "!=":
negated = True
Expand Down Expand Up @@ -1178,7 +1180,9 @@ def resolve_boolean_conditions(

return where, having

def _combine_conditions(self, lhs, rhs, operator):
def _combine_conditions(
self, lhs: List[WhereType], rhs: List[WhereType], operator: And | Or
) -> List[WhereType]:
combined_conditions = [
conditions[0] if len(conditions) == 1 else And(conditions=conditions)
for conditions in [lhs, rhs]
Expand Down Expand Up @@ -1441,9 +1445,9 @@ def _environment_filter_converter(self, search_filter: SearchFilter) -> Optional
# conditions added to env_conditions can be OR'ed
env_conditions = []
value = search_filter.value.value
values = set(value if isinstance(value, (list, tuple)) else [value])
values_set = set(value if isinstance(value, (list, tuple)) else [value])
# sorted for consistency
values = sorted(f"{value}" for value in values)
values = sorted(f"{value}" for value in values_set)
environment = self.column("environment")
# the "no environment" environment is null in snuba
if "" in values:
Expand Down Expand Up @@ -1564,12 +1568,11 @@ def _transaction_status_filter_converter(
self.resolve_field(search_filter.key.name),
Op.IS_NULL if search_filter.operator == "=" else Op.IS_NOT_NULL,
)
if search_filter.is_in_filter:
internal_value = [
translate_transaction_status(val) for val in search_filter.value.raw_value
]
else:
internal_value = translate_transaction_status(search_filter.value.raw_value)
internal_value = (
[translate_transaction_status(val) for val in search_filter.value.raw_value]
if search_filter.is_in_filter
else translate_transaction_status(search_filter.value.raw_value)
)
return Condition(
self.resolve_field(search_filter.key.name),
Op(search_filter.operator),
Expand Down Expand Up @@ -1661,8 +1664,8 @@ def _release_stage_filter_converter(self, search_filter: SearchFilter) -> Option
raise ValueError("organization_id is a required param")

organization_id: int = self.params["organization_id"]
project_ids: Optional[list[int]] = self.params.get("project_id")
environments: Optional[list[Environment]] = self.params.get("environment_objects", [])
project_ids: Optional[List[int]] = self.params.get("project_id")
environments: Optional[List[Environment]] = self.params.get("environment_objects", [])
qs = (
Release.objects.filter_by_stage(
organization_id,
Expand Down Expand Up @@ -1729,7 +1732,7 @@ def _semver_filter_converter(self, search_filter: SearchFilter) -> Optional[Wher
raise ValueError("organization_id is a required param")

organization_id: int = self.params["organization_id"]
project_ids: Optional[list[int]] = self.params.get("project_id")
project_ids: Optional[List[int]] = self.params.get("project_id")
# We explicitly use `raw_value` here to avoid converting wildcards to shell values
version: str = search_filter.value.raw_value
operator: str = search_filter.operator
Expand Down Expand Up @@ -1787,7 +1790,7 @@ def _semver_package_filter_converter(self, search_filter: SearchFilter) -> Optio
raise ValueError("organization_id is a required param")

organization_id: int = self.params["organization_id"]
project_ids: Optional[list[int]] = self.params.get("project_id")
project_ids: Optional[List[int]] = self.params.get("project_id")
package: str = search_filter.value.raw_value

versions = list(
Expand All @@ -1813,7 +1816,7 @@ def _semver_build_filter_converter(self, search_filter: SearchFilter) -> Optiona
raise ValueError("organization_id is a required param")

organization_id: int = self.params["organization_id"]
project_ids: Optional[list[int]] = self.params.get("project_id")
project_ids: Optional[List[int]] = self.params.get("project_id")
build: str = search_filter.value.raw_value

operator, negated = handle_operator_negation(search_filter.operator)
Expand Down

0 comments on commit f5ecc4a

Please sign in to comment.