Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ability to add BulkValue() into django expression #88

Merged
merged 1 commit into from
Feb 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ There are 4 query helpers in this library. There parameters are unified and desc
[Func](https://docs.djangoproject.com/en/3.2/ref/models/expressions/#func-expressions) and their child classes.
It can not use annotations and tables other than updated model (like `F(a__b__c)`).
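In create operations, the field's default value is used in place of the column value. If the field has no default, the default value for its field type is used (for example, `IntegerField` defaults to 0).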
Expression does not expect any value in `values` parameter and will ignore it if given.
The expression can contain `django_pg_bulk_update.set_functions.BulkValue()` instances.
If so, each of them is replaced with the field value passed in the `values` parameter.
If the expression does not contain any `BulkValue` instances, data passed in the `values` parameter for this key is ignored.

+ Function alias name
- 'eq', '='
Expand Down Expand Up @@ -197,6 +199,7 @@ There are 4 query helpers in this library. There parameters are unified and desc
from django.db import models, F
from django.db.models.functions import Upper
from django_pg_bulk_update import bulk_update, bulk_update_or_create, pdnf_clause
from django_pg_bulk_update.set_functions import BulkValue

# Test model
class TestModel(models.Model):
Expand Down Expand Up @@ -293,22 +296,24 @@ print(list(TestModel.objects.all().order_by("id").values("id", "name", "int_fiel

res = bulk_update_or_create(TestModel, [{
"id": 3,
"name": "_concat1"
"name": "_concat1",
"int_field": 3
}, {
"id": 4,
"name": "concat2"
}], set_functions={'name': '||', 'int_field': F('int_field') + 1})
"name": "concat2",
'int_field': 4
}], set_functions={'name': '||', 'int_field': F('int_field') + BulkValue()})

print(res)
# Outputs: 2

print(list(TestModel.objects.all().order_by("id").values("id", "name", "int_field")))
# Note: IntegerField defaults to 0 in create operations. So 0 + 1 = 1.
# Note: IntegerField defaults to 0 in create operations. So 0 + 4 = 4 for id 4.
# Outputs: [
# {"id": 1, "name": "updated1", "int_field": 2},
# {"id": 2, "name": "updated2", "int_field": 3},
# {"id": 3, "name": "incr_concat1", "int_field": 5},
# {"id": 4, "name": "concat2", "int_field": 1},
# {"id": 3, "name": "incr_concat1", "int_field": 7},
# {"id": 4, "name": "concat2", "int_field": 4},
# ]

# Find records where
Expand Down Expand Up @@ -357,7 +362,7 @@ TestModel.objects.pg_bulk_update([
# Any data here
], key_fields='id', set_functions=None, key_fields_ops=())

# Update only records with id gtreater than 5
# Update only records with id greater than or equal to 5
TestModel.objects.filter(id__gte=5).pg_bulk_update([
# Any data here
], key_fields='id', set_functions=None, key_fields_ops=())
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

setup(
name='django-pg-bulk-update',
version='3.5.1',
version='3.6.0',
packages=['django_pg_bulk_update'],
package_dir={'': 'src'},
url='https://github.com/M1hacka/django-pg-bulk-update',
Expand Down
125 changes: 109 additions & 16 deletions src/django_pg_bulk_update/set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,55 @@
}


if django_expressions_available():
    from django.db.models.expressions import Expression

    class BulkValue(Expression):
        """
        A django Expression placeholder, which is replaced with a concrete value before SQL generation.
        The value must be provided via set_value() method before as_sql() is called.
        """
        def __init__(self, *args, **kwargs):
            super(BulkValue, self).__init__(*args, **kwargs)

            # Until set_value() is called, the instance holds no real data and as_sql() refuses to work
            self._ready = False
            self._value = None
            self._val_as_param = False

        def __ror__(self, other):
            # Delegate to the matching reflected operator of the parent class (not to __rand__)
            return super(BulkValue, self).__ror__(other)

        def __rand__(self, other):
            return super(BulkValue, self).__rand__(other)

        def set_value(self, val, val_as_param):  # type: (Any, bool) -> None
            """
            Replaces fake initial data with real values, passed from query values parameter
            :param val: Value to return in SQL
            :param val_as_param: If flag is not set, value should be converted to string and inserted into query directly.
                Otherwise a placeholder and query parameter will be used
            :return: None
            """
            self._value = val
            self._val_as_param = val_as_param
            self._ready = True

        def as_sql(self, compiler, connection):
            """
            Forms SQL for the value set by set_value() method
            :param compiler: SQL compiler, passed to django Value.as_sql()
            :param connection: Database connection, passed to django Value.as_sql()
            :return: A tuple of SQL string and a list of query parameters
            :raises ValueError: If set_value() has not been called before
            """
            if not self._ready:
                raise ValueError('BulkValue instance has not been initialized before using as_sql method')

            if not self._val_as_param:
                # The value is converted to string and put into SQL text as is, without any parameters.
                # NOTE(review): presumably the caller passes an already formatted SQL fragment here - confirm
                return str(self._value), []

            from django.db.models.expressions import Value
            return Value(self._value).as_sql(compiler, connection)

else:
    # Dummy class for django before 1.8
    class BulkValue:
        def as_sql(self, compiler, connection):
            raise NotImplementedError("This is not supported in your django version. Please, upgrade")


class AbstractSetFunction(AbstractFieldFormatter):
names = set()

Expand Down Expand Up @@ -189,34 +238,54 @@ def _get_field_column(self, field, with_table=False):


class DjangoSetFunction(AbstractSetFunction):
needs_value = False

def __init__(self, django_expression):  # type: (BaseExpression) -> None  # noqa: F821
    """
    :param django_expression: Django expression used to calculate the value to set
    :raises NotImplementedError: If django expressions are not available in the installed django version
    """
    if not django_expressions_available():
        # Exceptions must be exception instances: raising a plain string fails with TypeError
        # and the message is lost
        raise NotImplementedError('Django expressions are available since django 1.8, please upgrade')

    self._django_expression = django_expression
    # Cache for the needs_value property. None means "not calculated yet"
    self._needs_value = None

@property
def needs_value(self):  # type: () -> bool
    """
    If expression contains BulkValue() references, value is required, otherwise - not.
    The result is calculated on first access and cached in self._needs_value.
    :return: Boolean
    """
    if self._needs_value is None:
        found = []

        def search_bulk_value_callback(expr):
            # Expression is returned untouched: the traversal is used for searching only
            found.append(expr)
            return expr

        self._modify_expression_recursively(BulkValue, self._django_expression, search_bulk_value_callback)

        # Store an explicit boolean, so that a negative result is cached too
        # and the property never returns None
        self._needs_value = bool(found)

    return self._needs_value

@classmethod
def _modify_column_refs_recursively(cls, expr, callback):
# type: (BaseExpression, Callable) -> BaseExpression # noqa: F821
def _modify_expression_recursively(cls, target_expression_class, expr, callback):
# type: (BaseExpression, BaseExpression, Callable) -> BaseExpression # noqa: F821
"""
Recursively iterates expression, searching for target_expression_class instances and calls callback for every found expression
:param target_expression_class: Expression class to search for
:param expr: Expression to process
:param callback: Function to apply to every expression. Should take exactly 1 argument: Col instance
:param callback: Function to apply to every expression. Should take exactly 1 argument:
target_expression_class instance
:return: Processed expression
"""
from django.db.models.expressions import Col

if isinstance(expr, Col):
if isinstance(expr, target_expression_class):
return callback(expr)

if not hasattr(expr, 'get_source_expressions'):
return expr

src_expressions = expr.get_source_expressions()
if not src_expressions:
return expr

expr = expr.copy()
new_src_expressions = [cls._modify_column_refs_recursively(sub_expr, callback) for sub_expr in src_expressions]
new_src_expressions = [
cls._modify_expression_recursively(target_expression_class, sub_expr, callback)
for sub_expr in src_expressions
]
expr.set_source_expressions(new_src_expressions)
return expr

Expand All @@ -237,10 +306,10 @@ def replace_with_default_values(col): # type: (Col) -> BaseExpression
default_value = NULL_DEFAULTS.get(col.field.__class__.__name__)
return Value(default_value)

return cls._modify_column_refs_recursively(expr, replace_with_default_values)
return cls._modify_expression_recursively(Col, expr, replace_with_default_values)

@classmethod
def remove_aliases_from_expression(cls, expr):
def remove_aliases_from_expression(cls, expr): # type: (BaseExpression) -> BaseExpression # noqa: F821
"""
Removes table alias for functions which reference columns, if with_table flag is False
In django 3.1+ This can be achieved by alias_cols=False Query flag.
Expand All @@ -260,7 +329,26 @@ def as_sql(compiler, connection):
col.as_sql = as_sql
return col

return cls._modify_column_refs_recursively(expr, remove_alias_from_col)
return cls._modify_expression_recursively(Col, expr, remove_alias_from_col)

@classmethod
def replace_bulk_value_in_expression(cls, expr, val, val_as_param):
    # type: (BaseExpression, Any, bool) -> BaseExpression  # noqa: F821
    """
    Passes real data to every BulkValue instance found in the expression
    :param expr: Expression to process
    :param val: Value passed to get_sql_value method
    :param val_as_param: If flag is not set, value should be converted to string and inserted into query directly.
        Otherwise, a placeholder and query parameter will be used
    :return: The expression passed in expr parameter. Its BulkValue instances are modified in place
    """
    def set_bulk_value_real_data(bulk_value):
        bulk_value.set_value(val, val_as_param)
        return bulk_value

    # BulkValue instances are updated in place by the callback,
    # so the result of the traversal is not used and the original expression is returned
    cls._modify_expression_recursively(BulkValue, expr, set_bulk_value_real_data)
    return expr

@classmethod
def get_query(cls, field, with_table=False, for_update=True): # type: (Field, bool, bool) -> Query
Expand All @@ -276,15 +364,18 @@ def get_query(cls, field, with_table=False, for_update=True): # type: (Field, b
return query

@classmethod
def resolve_expression(cls, field, expr, connection, with_table=False, for_update=True):
# type: (Field, Any, TDatabase, bool, bool) -> Tuple[SQLCompiler, BaseExpression] # noqa: F821
def resolve_expression(cls, field, expr, connection, val, val_as_param=False, with_table=False, for_update=True):
# type: (Field, Any, TDatabase, Any, bool, bool, bool) -> Tuple[SQLCompiler, BaseExpression] # noqa: F821
"""
Processes django expression, preparing it for SQL Generation
Note: expression resolve has been mostly copied from SQLUpdateCompiler.as_sql() method
and adopted for this function purposes
:param field: Django field expression will be applied to
:param expr: Expression to process
:param connection: Connection used to update data
:param val: Value passed to get_sql_value method
:param val_as_param: If flag is not set, value should be converted to string and inserted into query directly.
Otherwise, a placeholder and query parameter will be used
:param with_table: If flag is set, column name in sql is prefixed by table name
:param for_update: If flag is set, returns update sql. Otherwise - insert SQL
:return: A tuple of compiler used to format expression and result expression
Expand All @@ -293,6 +384,7 @@ def resolve_expression(cls, field, expr, connection, with_table=False, for_updat
compiler = query.get_compiler(connection=connection)

compiler.pre_sql_setup()
expr = cls.replace_bulk_value_in_expression(expr, val, val_as_param)
expr = expr.resolve_expression(query=query, allow_joins=False, for_save=True)
if expr.contains_aggregate:
raise FieldError(
Expand Down Expand Up @@ -321,12 +413,13 @@ def modify_create_params(self, model, key, kwargs, connection):

# Django sets its built-in defaults here. Let's replace column aliases with this library's defaults
field = model._meta.get_field(key)
_, expr = self.resolve_expression(field, self._django_expression, connection, for_update=False)
_, expr = self.resolve_expression(field, self._django_expression, connection, kwargs.get(key), for_update=False)
kwargs[key] = expr
return kwargs

def get_sql_value(self, field, val, connection, val_as_param=True, with_table=False, for_update=True, **kwargs):
compiler, expr = self.resolve_expression(field, self._django_expression, connection, with_table=with_table,
compiler, expr = self.resolve_expression(field, self._django_expression, connection, val,
val_as_param=val_as_param, with_table=with_table,
for_update=for_update)

# SQL forming is copied from SQLUpdateCompiler.as_sql() and adopted for this function purposes
Expand Down
7 changes: 4 additions & 3 deletions tests/test_bulk_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from django_pg_bulk_update.compatibility import jsonb_available, hstore_available, array_available, tz_utc, \
django_expressions_available
from django_pg_bulk_update.query import bulk_create
from django_pg_bulk_update.set_functions import ConcatSetFunction
from django_pg_bulk_update.set_functions import ConcatSetFunction, BulkValue
from tests.models import TestModel, UpperCaseModel, AutoNowModel, TestModelWithSchema, UUIDFieldPrimaryModel


Expand Down Expand Up @@ -481,11 +481,12 @@ def test_auto_now_respects_override(self):
def test_django_expression(self):
# Default for IntegerField should be 0
from django.db.models import F
res = bulk_create(TestModel, [{'id': 11}, {'id': 12}], set_functions={'int_field': F('int_field') + 1})
res = bulk_create(TestModel, [{'id': 11, 'int_field': 1}, {'id': 12, 'int_field': 2}],
set_functions={'int_field': F('int_field') + BulkValue()})

self.assertEqual(2, res)
for instance in TestModel.objects.filter(pk__in={11, 12}):
self.assertEqual(1, instance.int_field)
self.assertEqual(instance.pk - 10, instance.int_field)
self.assertEqual('', instance.name)


Expand Down
8 changes: 4 additions & 4 deletions tests/test_bulk_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from django_pg_bulk_update.compatibility import jsonb_available, hstore_available, array_available, tz_utc, \
django_expressions_available
from django_pg_bulk_update.query import bulk_update
from django_pg_bulk_update.set_functions import ConcatSetFunction
from django_pg_bulk_update.set_functions import ConcatSetFunction, BulkValue
from tests.models import TestModel, RelationModel, UpperCaseModel, AutoNowModel, TestModelWithSchema, \
UUIDFieldPrimaryModel

Expand Down Expand Up @@ -697,12 +697,12 @@ def test_django_expression(self):
from django.db.models import F
from django.db.models.functions import Upper

res = bulk_update(TestModel, [{'id': 1}, {'id': 2}],
set_functions={'name': Upper('name'), 'int_field': F('int_field') + 1})
res = bulk_update(TestModel, [{'id': 1, 'int_field': 2}, {'id': 2, 'int_field': 4}],
set_functions={'name': Upper('name'), 'int_field': F('int_field') + BulkValue()})

self.assertEqual(2, res)
for instance in TestModel.objects.filter(pk__in={1, 2}):
self.assertEqual(instance.pk + 1, instance.int_field)
self.assertEqual(3 * instance.pk, instance.int_field)
self.assertEqual('TEST%d' % instance.pk, instance.name)


Expand Down
40 changes: 35 additions & 5 deletions tests/test_bulk_update_or_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from django_pg_bulk_update.compatibility import jsonb_available, array_available, hstore_available, tz_utc, \
django_expressions_available
from django_pg_bulk_update.query import bulk_update_or_create
from django_pg_bulk_update.set_functions import ConcatSetFunction
from django_pg_bulk_update.set_functions import ConcatSetFunction, BulkValue
from django_pg_returning import ReturningQuerySet
from tests.compatibility import get_auto_now_date
from tests.models import TestModel, UniqueNotPrimary, UpperCaseModel, AutoNowModel, TestModelWithSchema, \
Expand Down Expand Up @@ -654,6 +654,33 @@ def test_example(self):
{"id": 4, "name": "concat2", "int_field": 1},
], list(TestModel.objects.all().order_by("id").values("id", "name", "int_field")))

@skipIf(not django_expressions_available(), "Django expressions are not supported")
def test_example_with_bulk_value(self):
    # bulk_create and bulk_update parts of the example are covered by another test.
    # Here data is initialized in the state the bulk_update_or_create part starts from
    initial_rows = [
        TestModel(pk=1, name="updated1", int_field=2),
        TestModel(pk=2, name="updated2", int_field=3),
        TestModel(pk=3, name="incr", int_field=4),
    ]
    TestModel.objects.bulk_create(initial_rows)

    values = [
        {"id": 3, "name": "_concat1", "int_field": 3},
        {"id": 4, "name": "concat2", "int_field": 4},
    ]
    set_functions = {'name': '||', 'int_field': F('int_field') + BulkValue()}

    res = bulk_update_or_create(TestModel, values, set_functions=set_functions)
    self.assertEqual(2, res)

    # id 3 is updated: 4 + 3 = 7. id 4 is created: 0 + 4 = 4
    expected = [
        {"id": 1, "name": "updated1", "int_field": 2},
        {"id": 2, "name": "updated2", "int_field": 3},
        {"id": 3, "name": "incr_concat1", "int_field": 7},
        {"id": 4, "name": "concat2", "int_field": 4},
    ]
    actual = list(TestModel.objects.all().order_by("id").values("id", "name", "int_field"))
    self.assertListEqual(expected, actual)


class TestSetFunctions(TestCase):
fixtures = ['test_model']
Expand Down Expand Up @@ -862,19 +889,22 @@ def test_django_expression(self):

res = bulk_update_or_create(TestModel, [{
'id': 1,
'int_field': 1
}, {
'id': 5,
'int_field': 5
}, {
'id': 11
}], set_functions={'int_field': F('int_field') + 1, 'name': Upper('name')})
'id': 11,
'int_field': 11
}], set_functions={'int_field': F('int_field') + BulkValue(), 'name': Upper('name')})

self.assertEqual(3, res)
for pk, name, int_field in TestModel.objects.all().order_by('id').values_list('id', 'name', 'int_field'):
if pk in {1, 5}:
self.assertEqual(pk + 1, int_field)
self.assertEqual(pk * 2, int_field)
self.assertEqual('TEST%d' % pk, name)
elif pk > 10:
self.assertEqual(1, int_field)
self.assertEqual(pk, int_field)
self.assertEqual('', name)
else:
self.assertEqual(pk, int_field)
Expand Down