Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for SQLAlchemy 1.4 #260

Merged
merged 10 commits into from
Jan 18, 2022
73 changes: 73 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
name: "Test"

on:
push:
paths-ignore:
- "docs/**"
pull_request:
paths-ignore:
- "docs/**"
schedule:
- cron: '40 1 * * 3'


jobs:
test:
name: test-python${{ matrix.python-version }}-sa${{ matrix.sqlalchemy-version }}-${{ matrix.db-engine }}
strategy:
matrix:
python-version:
# - "2.7"
# - "3.4"
# - "3.5"
# - "3.6"
# - "3.7"
- "3.8"
# - "3.9"
# - "3.10"
# - "pypy-3.7"
sqlalchemy-version:
- "<1.4"
- ">=1.4"
db-engine:
- sqlite
- postgres
- postgres-native
- mysql
runs-on: ubuntu-latest
services:
mysql:
image: mysql
ports:
- 3306:3306
env:
MYSQL_DATABASE: sqlalchemy_continuum_test
MYSQL_ALLOW_EMPTY_PASSWORD: yes
options: >-
--health-cmd "mysqladmin ping"
--health-interval 5s
--health-timeout 2s
--health-retries 3
postgres:
image: postgres
ports:
- 5432:5432
env:
POSTGRES_PASSWORD: postgres
POSTGRES_DB: sqlalchemy_continuum_test
options: >-
--health-cmd pg_isready
--health-interval 5s
--health-timeout 2s
--health-retries 3
steps:
- uses: actions/checkout@v1
- name: Install sqlalchemy
run: pip3 install 'sqlalchemy${{ matrix.sqlalchemy-version }}'
- name: Build
run: pip3 install -e '.[test]'
- name: Run tests
run: pytest
env:
DB: ${{ matrix.db-engine }}

2 changes: 1 addition & 1 deletion benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_versioning(

make_versioned(options=options)

dns = 'postgres://postgres@localhost/sqlalchemy_continuum_test'
dns = 'postgresql://postgres:postgres@localhost/sqlalchemy_continuum_test'
versioning_manager.plugins = plugins
versioning_manager.transaction_cls = transaction_cls
versioning_manager.user_cls = user_cls
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ def get_version():
'pytest>=2.3.5',
'flexmock>=0.9.7',
'psycopg2>=2.4.6',
'PyMySQL==0.6.1',
'PyMySQL>=0.8.0',
'six>=1.4.0'
],
'anyjson': ['anyjson>=0.3.3'],
'flask': ['Flask>=0.9'],
'flask-login': ['Flask-Login>=0.2.9'],
'flask-sqlalchemy': ['Flask-SQLAlchemy>=1.0'],
'flexmock': ['flexmock>=0.9.7'],
'i18n': ['SQLAlchemy-i18n>=0.8.4'],
'i18n': ['SQLAlchemy-i18n>=0.8.4,!=1.1.0'],
}


Expand Down
14 changes: 14 additions & 0 deletions sqlalchemy_continuum/builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from copy import copy
from inspect import getmro
from functools import wraps

import sqlalchemy as sa
from sqlalchemy_utils.functions import get_declarative_base
Expand All @@ -10,6 +11,18 @@
from .table_builder import TableBuilder


def prevent_reentry(handler):
in_handler = False
@wraps(handler)
def check_reentry(*args, **kwargs):
nonlocal in_handler
if in_handler:
return
in_handler = True
handler(*args, **kwargs)
in_handler = False
return check_reentry

class Builder(object):
def build_triggers(self):
"""
Expand Down Expand Up @@ -141,6 +154,7 @@ def build_transaction_class(self):
self.manager.create_transaction_model()
self.manager.plugins.after_build_tx_class(self.manager)

@prevent_reentry
def configure_versioned_classes(self):
"""
Configures all versioned classes that were collected during
Expand Down
6 changes: 5 additions & 1 deletion sqlalchemy_continuum/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ def __call__(self, manager):
Create model class but only if it doesn't already exist
in declarative model registry.
"""
registry = manager.declarative_base._decl_class_registry
Base = manager.declarative_base
try:
registry = Base.registry._class_registry
except AttributeError: # SQLAlchemy < 1.4
registry = Base._decl_class_registry
if self.model_name not in registry:
return self.create_class(manager)
return registry[self.model_name]
4 changes: 2 additions & 2 deletions sqlalchemy_continuum/fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _transaction_id_subquery(self, obj, next_or_prev='next', alias=None):
func = sa.func.max

if alias is None:
alias = sa.orm.aliased(obj)
alias = sa.orm.aliased(obj.__class__)
table = alias.__table__
if hasattr(alias, 'c'):
attrs = alias.c
Expand Down Expand Up @@ -117,7 +117,7 @@ def _index_query(self, obj):
Returns the query needed for fetching the index of this record relative
to version history.
"""
alias = sa.orm.aliased(obj)
alias = sa.orm.aliased(obj.__class__)

subquery = (
sa.select([sa.func.count('1')], from_obj=[alias.__table__])
Expand Down
6 changes: 5 additions & 1 deletion sqlalchemy_continuum/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,11 @@ class Transaction(

if manager.user_cls:
user_cls = manager.user_cls
registry = manager.declarative_base._decl_class_registry
Base = manager.declarative_base
try:
registry = Base.registry._class_registry
except AttributeError: # SQLAlchemy < 1.4
registry = Base._decl_class_registry

if isinstance(user_cls, six.string_types):
try:
Expand Down
4 changes: 2 additions & 2 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def log_sql(

def get_dns_from_driver(driver):
if driver == 'postgres':
return 'postgres://postgres@localhost/sqlalchemy_continuum_test'
return 'postgresql://postgres:postgres@localhost/sqlalchemy_continuum_test'
elif driver == 'mysql':
return 'mysql+pymysql://travis@localhost/sqlalchemy_continuum_test'
return 'mysql+pymysql://root@localhost/sqlalchemy_continuum_test'
elif driver == 'sqlite':
return 'sqlite:///:memory:'
else:
Expand Down
1 change: 1 addition & 0 deletions tests/inheritance/test_single_table_inheritance.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class TextItem(self.Model):

__mapper_args__ = {
'polymorphic_on': discriminator,
'polymorphic_identity': u'base',
'with_polymorphic': '*'
}

Expand Down
2 changes: 1 addition & 1 deletion tests/relationships/test_association_table_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class PublishedArticle(self.Model):
__tablename__ = 'published_article'
__table_args__ = (
PrimaryKeyConstraint("article_id", "author_id"),
{'useexisting': True}
{'keep_existing': True}
)

article_id = sa.Column(sa.Integer, sa.ForeignKey('article.id'))
Expand Down
5 changes: 5 additions & 0 deletions tests/test_mapper_args.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from pytest import mark
from packaging import version

import sqlalchemy as sa
from sqlalchemy_continuum import version_class
from tests import TestCase
Expand Down Expand Up @@ -29,6 +32,7 @@ def test_supports_column_prefix(self):
assert self.TextItem._id


@mark.skipif("version.parse(sa.__version__) >= version.parse('1.4')")
class TestOrderByWithStringArg(TestCase):
def create_models(self):
class TextItem(self.Model):
Expand All @@ -55,6 +59,7 @@ def test_reflects_order_by(self):
assert self.TextItemVersion.__mapper_args__['order_by'] == 'id'


@mark.skipif("version.parse(sa.__version__) >= version.parse('1.4')")
class TestOrderByWithInstrumentedAttribute(TestCase):
def create_models(self):
class TextItem(self.Model):
Expand Down