Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: always denorm column value before querying values #25919

Merged
merged 8 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions superset/connectors/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,13 +496,6 @@ def query(self, query_obj: QueryObjectDict) -> QueryResult:
"""
raise NotImplementedError()

def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
"""Given a column, returns an iterable of distinct values
This is used to populate the dropdown showing a list of
values in filters in the explore view"""
raise NotImplementedError()

@staticmethod
def default_query(qry: Query) -> Query:
return qry
Expand Down
29 changes: 0 additions & 29 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
inspect,
Integer,
or_,
select,
String,
Table,
Text,
Expand Down Expand Up @@ -793,34 +792,6 @@ def get_fetch_values_predicate(
)
) from ex

def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
"""Runs query against sqla to retrieve some
sample values for the given column.
"""
cols = {col.column_name: col for col in self.columns}
target_col = cols[column_name]
tp = self.get_template_processor()
tbl, cte = self.get_from_clause(tp)

qry = (
select([target_col.get_sqla_col(template_processor=tp)])
.select_from(tbl)
.distinct()
)
if limit:
qry = qry.limit(limit)

if self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate(template_processor=tp))

with self.database.get_sqla_engine_with_context() as engine:
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
sql = self._apply_cte(sql, cte)
sql = self.mutate_query_from_config(sql)

df = pd.read_sql_query(sql=sql, con=engine)
return df[column_name].to_list()

def mutate_query_from_config(self, sql: str) -> str:
"""Apply config's SQL_QUERY_MUTATOR

Expand Down
4 changes: 4 additions & 0 deletions superset/datasource/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ def get_column_values(
column_name=column_name, limit=row_limit
)
return self.response(200, result=payload)
except KeyError:
return self.response(
400, message=f"Column name {column_name} does not exist"
)
except NotImplementedError:
return self.response(
400,
Expand Down
56 changes: 27 additions & 29 deletions superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,10 +705,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
"MIN": sa.func.MIN,
"MAX": sa.func.MAX,
}

@property
def fetch_value_predicate(self) -> str:
return "fix this!"
fetch_values_predicate = None

@property
def type(self) -> str:
Expand Down Expand Up @@ -785,17 +782,20 @@ def sql(self) -> str:
def columns(self) -> list[Any]:
raise NotImplementedError()

def get_fetch_values_predicate(
self, template_processor: Optional[BaseTemplateProcessor] = None
) -> TextClause:
raise NotImplementedError()

def get_extra_cache_keys(self, query_obj: dict[str, Any]) -> list[Hashable]:
raise NotImplementedError()

def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
raise NotImplementedError()

def get_fetch_values_predicate(
self,
template_processor: Optional[ # pylint: disable=unused-argument
BaseTemplateProcessor
] = None, # pylint: disable=unused-argument
) -> TextClause:
return self.fetch_values_predicate

def get_sqla_row_level_filters(
self,
template_processor: BaseTemplateProcessor,
Expand Down Expand Up @@ -1341,36 +1341,34 @@ def get_time_filter( # pylint: disable=too-many-arguments
return and_(*l)

def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
"""Runs query against sqla to retrieve some
sample values for the given column.
"""
cols = {}
for col in self.columns:
if isinstance(col, dict):
cols[col.get("column_name")] = col
else:
cols[col.column_name] = col

target_col = cols[column_name]
tp = None # todo(hughhhh): add back self.get_template_processor()
# always denormalize column name before querying for values
db_dialect = self.database.get_dialect()
denomalized_col_name = self.database.db_engine_spec.denormalize_name(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The denormalized name will fall back to the original value if the engine can't adjust it:
preset-io/superset@60e1526/superset/db_engine_specs/base.py#L1950

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: denomalized -> denormalized

db_dialect, column_name
)
cols = {col.column_name: col for col in self.columns}
target_col = cols[denomalized_col_name]
tp = self.get_template_processor()
tbl, cte = self.get_from_clause(tp)

if isinstance(target_col, dict):
sql_column = sa.column(target_col.get("name"))
else:
sql_column = target_col

qry = sa.select([sql_column]).select_from(tbl).distinct()
qry = (
sa.select([target_col.get_sqla_col(template_processor=tp)])
.select_from(tbl)
.distinct()
)
if limit:
qry = qry.limit(limit)

if self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate(template_processor=tp))

with self.database.get_sqla_engine_with_context() as engine:
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
sql = self._apply_cte(sql, cte)
sql = self.mutate_query_from_config(sql)

df = pd.read_sql_query(sql=sql, con=engine)
return df[column_name].to_list()
return df[denomalized_col_name].to_list()

def get_timestamp_expression(
self,
Expand Down Expand Up @@ -1942,7 +1940,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
)
having_clause_and += [self.text(having)]

if apply_fetch_values_predicate and self.fetch_values_predicate: # type: ignore
if apply_fetch_values_predicate and self.fetch_values_predicate:
qry = qry.where(
self.get_fetch_values_predicate(template_processor=template_processor)
)
Expand Down
Loading