Skip to content

Commit

Permalink
Merge "Render labels in autogenerate index" into main
Browse files Browse the repository at this point in the history
  • Loading branch information
zzzeek authored and Gerrit Code Review committed Feb 12, 2025
2 parents 7b54f1c + 762b2d2 commit f1ae53e
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 28 deletions.
34 changes: 20 additions & 14 deletions alembic/autogenerate/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sqlalchemy import sql
from sqlalchemy import types as sqltypes
from sqlalchemy.sql.elements import conv
from sqlalchemy.sql.elements import Label
from sqlalchemy.sql.elements import quoted_name

from .. import util
Expand Down Expand Up @@ -584,23 +585,28 @@ def _render_potential_expr(
value: Any,
autogen_context: AutogenContext,
*,
wrap_in_text: bool = True,
wrap_in_element: bool = True,
is_server_default: bool = False,
is_index: bool = False,
) -> str:
if isinstance(value, sql.ClauseElement):
if wrap_in_text:
template = "%(prefix)stext(%(sql)r)"
sql_text = autogen_context.migration_context.impl.render_ddl_sql_expr(
value, is_server_default=is_server_default, is_index=is_index
)
if wrap_in_element:
prefix = _sqlalchemy_autogenerate_prefix(autogen_context)
element = "literal_column" if is_index else "text"
value_str = f"{prefix}{element}({sql_text!r})"
if (
is_index
and isinstance(value, Label)
and type(value.name) is str
):
return value_str + f".label({value.name!r})"
else:
return value_str
else:
template = "%(sql)r"

return template % {
"prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
"sql": autogen_context.migration_context.impl.render_ddl_sql_expr(
value, is_server_default=is_server_default, is_index=is_index
),
}

return repr(sql_text)
else:
return repr(value)

Expand Down Expand Up @@ -787,7 +793,7 @@ def _render_computed(
computed: Computed, autogen_context: AutogenContext
) -> str:
text = _render_potential_expr(
computed.sqltext, autogen_context, wrap_in_text=False
computed.sqltext, autogen_context, wrap_in_element=False
)

kwargs = {}
Expand Down Expand Up @@ -1101,7 +1107,7 @@ def _render_check_constraint(
else ""
),
"sqltext": _render_potential_expr(
constraint.sqltext, autogen_context, wrap_in_text=False
constraint.sqltext, autogen_context, wrap_in_element=False
),
}

Expand Down
2 changes: 1 addition & 1 deletion alembic/ddl/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,5 +846,5 @@ def _render_potential_column(
return render._render_potential_expr(
value,
autogen_context,
wrap_in_text=isinstance(value, (TextClause, FunctionElement)),
wrap_in_element=isinstance(value, (TextClause, FunctionElement)),
)
3 changes: 1 addition & 2 deletions alembic/op.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ if TYPE_CHECKING:
from sqlalchemy.sql.elements import conv
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.expression import TableClause
from sqlalchemy.sql.functions import Function
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Computed
from sqlalchemy.sql.schema import Identity
Expand Down Expand Up @@ -650,7 +649,7 @@ def create_foreign_key(
def create_index(
index_name: Optional[str],
table_name: str,
columns: Sequence[Union[str, TextClause, Function[Any]]],
columns: Sequence[Union[str, TextClause, ColumnElement[Any]]],
*,
schema: Optional[str] = None,
unique: bool = False,
Expand Down
3 changes: 1 addition & 2 deletions alembic/operations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from sqlalchemy.sql.expression import ColumnElement
from sqlalchemy.sql.expression import TableClause
from sqlalchemy.sql.expression import TextClause
from sqlalchemy.sql.functions import Function
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Computed
from sqlalchemy.sql.schema import Identity
Expand Down Expand Up @@ -1074,7 +1073,7 @@ def create_index(
self,
index_name: Optional[str],
table_name: str,
columns: Sequence[Union[str, TextClause, Function[Any]]],
columns: Sequence[Union[str, TextClause, ColumnElement[Any]]],
*,
schema: Optional[str] = None,
unique: bool = False,
Expand Down
3 changes: 1 addition & 2 deletions alembic/operations/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from sqlalchemy.sql.elements import conv
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.functions import Function
from sqlalchemy.sql.schema import CheckConstraint
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Computed
Expand Down Expand Up @@ -933,7 +932,7 @@ def create_index(
operations: Operations,
index_name: Optional[str],
table_name: str,
columns: Sequence[Union[str, TextClause, Function[Any]]],
columns: Sequence[Union[str, TextClause, ColumnElement[Any]]],
*,
schema: Optional[str] = None,
unique: bool = False,
Expand Down
7 changes: 7 additions & 0 deletions docs/build/unreleased/1603.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
.. change::
:tags: usecase, autogenerate
:tickets: 1603

Index autogenerate will now render labels for expressions
that use them. This is useful when applying operator classes
in PostgreSQL that can be keyed on the label name.
38 changes: 32 additions & 6 deletions tests/test_autogen_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,32 @@ def test_render_add_index(self):
"['active', 'code'], unique=False)",
)

def test_render_add_index_fn(self):
t = self.table(Column("other", String(100)))
idx = Index("test_fn_idx", t.c.code + t.c.other)
op_obj = ops.CreateIndexOp.from_index(idx)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_index('test_fn_idx', 'test', "
"[sa.literal_column('code || other')], unique=False)",
)

def test_render_add_index_label(self):
t = self.table(Column("other", String(100)))
idx = Index(
"test_fn_idx",
(t.c.code + t.c.other).label("foo"),
t.c.id.label("bar"),
)
op_obj = ops.CreateIndexOp.from_index(idx)
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_index('test_fn_idx', 'test', ["
"sa.literal_column('code || other').label('foo'), "
"sa.literal_column('id').label('bar')"
"], unique=False)",
)

def test_render_add_index_if_not_exists(self):
"""
autogenerate.render._add_index
Expand Down Expand Up @@ -170,7 +196,7 @@ def test_render_add_index_text(self):
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_index('test_active_code_idx', 'test', "
"['active', sa.text('lower(code)')], unique=False)",
"['active', sa.literal_column('lower(code)')], unique=False)",
)
op_obj_rev = op_obj.reverse()
eq_ignore_whitespace(
Expand All @@ -186,7 +212,7 @@ def test_render_add_index_func(self):
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_index('test_lower_code_idx', 'test', "
"[sa.text('lower(code)')], unique=False)",
"[sa.literal_column('lower(code)')], unique=False)",
)
op_obj_rev = op_obj.reverse()
eq_ignore_whitespace(
Expand All @@ -202,7 +228,7 @@ def test_render_add_index_cast(self):
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_index('test_lower_code_idx', 'test', "
"[sa.text('CAST(code AS VARCHAR)')], unique=False)",
"[sa.literal_column('CAST(code AS VARCHAR)')], unique=False)",
)

def test_render_add_index_desc(self):
Expand All @@ -212,7 +238,7 @@ def test_render_add_index_desc(self):
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj),
"op.create_index('test_desc_code_idx', 'test', "
"[sa.text('code DESC')], unique=False)",
"[sa.literal_column('code DESC')], unique=False)",
)

def test_drop_index(self):
Expand Down Expand Up @@ -256,7 +282,7 @@ def test_drop_index_text(self):
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj_rev),
"op.create_index('test_active_code_idx', 'test', "
"['active', sa.text('lower(code)')], unique=False)",
"['active', sa.literal_column('lower(code)')], unique=False)",
)

def test_drop_index_func(self):
Expand All @@ -274,7 +300,7 @@ def test_drop_index_func(self):
eq_ignore_whitespace(
autogenerate.render_op_text(self.autogen_context, op_obj_rev),
"op.create_index('test_lower_code_idx', 'test', "
"[sa.text('lower(code)')], unique=False)",
"[sa.literal_column('lower(code)')], unique=False)",
)

@testing.emits_warning("Can't validate argument ")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,7 +1328,7 @@ def test_jsonb_expression_in_index(self):
ops.CreateIndexOp.from_index(idx),
),
"op.create_index('my_idx', 'tbl', "
"[sa.text(\"(c ->> 'foo')\")], unique=False)",
"[sa.literal_column(\"(c ->> 'foo')\")], unique=False)",
)

@config.requirements.nulls_not_distinct_sa
Expand Down

0 comments on commit f1ae53e

Please sign in to comment.