Skip to content

Commit

Permalink
Merge pull request #88 from M1ha-Shvn/bulk-value/issue-82
Browse files Browse the repository at this point in the history
Added ability to add BulkValue() into django expression
  • Loading branch information
M1ha-Shvn authored Feb 8, 2022
2 parents f4829a0 + 99fe982 commit 84b7c47
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 37 deletions.
21 changes: 13 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ There are 4 query helpers in this library. There parameters are unified and desc
[Func](https://docs.djangoproject.com/en/3.2/ref/models/expressions/#func-expressions) and their child classes.
It can not use annotations and tables other than updated model (like `F(a__b__c)`).
In create operations the field's default value is used. If the field has no default, the default value for the field's type is used.
Expression does not expect any value in `values` parameter and will ignore it if given.
The expression may contain `django_pg_bulk_update.set_functions.BulkValue()` expressions.
If so, each of them is replaced with the field value passed in the `values` parameter.
If the expression does not contain any `BulkValue` instances, data passed in the `values` parameter for this key is ignored.

+ Function alias name
- 'eq', '='
Expand Down Expand Up @@ -197,6 +199,7 @@ There are 4 query helpers in this library. There parameters are unified and desc
from django.db import models, F
from django.db.models.functions import Upper
from django_pg_bulk_update import bulk_update, bulk_update_or_create, pdnf_clause
from django_pg_bulk_update.set_functions import BulkValue

# Test model
class TestModel(models.Model):
Expand Down Expand Up @@ -293,22 +296,24 @@ print(list(TestModel.objects.all().order_by("id").values("id", "name", "int_fiel

res = bulk_update_or_create(TestModel, [{
"id": 3,
"name": "_concat1"
"name": "_concat1",
"int_field": 3
}, {
"id": 4,
"name": "concat2"
}], set_functions={'name': '||', 'int_field': F('int_field') + 1})
"name": "concat2",
'int_field': 4
}], set_functions={'name': '||', 'int_field': F('int_field') + BulkValue()})

print(res)
# Outputs: 2

print(list(TestModel.objects.all().order_by("id").values("id", "name", "int_field")))
# Note: IntegerField defaults to 0 in create operations. So 0 + 1 = 1.
# Note: IntegerField defaults to 0 in create operations. So 0 + 4 = 4 for id 4.
# Outputs: [
# {"id": 1, "name": "updated1", "int_field": 2},
# {"id": 2, "name": "updated2", "int_field": 3},
# {"id": 3, "name": "incr_concat1", "int_field": 5},
# {"id": 4, "name": "concat2", "int_field": 1},
# {"id": 3, "name": "incr_concat1", "int_field": 7},
# {"id": 4, "name": "concat2", "int_field": 4},
# ]

# Find records where
Expand Down Expand Up @@ -357,7 +362,7 @@ TestModel.objects.pg_bulk_update([
# Any data here
], key_fields='id', set_functions=None, key_fields_ops=())

# Update only records with id gtreater than 5
# Update only records with id greater than or equal to 5
TestModel.objects.filter(id__gte=5).pg_bulk_update([
# Any data here
], key_fields='id', set_functions=None, key_fields_ops=())
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

setup(
name='django-pg-bulk-update',
version='3.5.1',
version='3.6.0',
packages=['django_pg_bulk_update'],
package_dir={'': 'src'},
url='https://github.com/M1hacka/django-pg-bulk-update',
Expand Down
125 changes: 109 additions & 16 deletions src/django_pg_bulk_update/set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,55 @@
}


if django_expressions_available():
    from django.db.models.expressions import Expression

    class BulkValue(Expression):
        """
        A placeholder django Expression which is replaced by a concrete value before SQL is generated.
        The value must be supplied with set_value() before as_sql() is called.
        """
        def __init__(self, *args, **kwargs):
            super(BulkValue, self).__init__(*args, **kwargs)

            # Placeholder state: as_sql() refuses to work until set_value() flips _ready
            self._ready = False
            self._value = None
            self._val_as_param = False

        def __ror__(self, other):
            # Delegate to the matching parent operator (previously delegated to __rand__ by mistake)
            return super(BulkValue, self).__ror__(other)

        def __rand__(self, other):
            return super(BulkValue, self).__rand__(other)

        def set_value(self, val, val_as_param):  # type: (Any, bool) -> None
            """
            Replaces fake initial data with real values, passed from query values parameter
            :param val: Value to return in SQL
            :param val_as_param: If flag is not set, value should be converted to string and inserted into query directly.
              Otherwise a placeholder and query parameter will be used
            :return: None
            """
            self._value = val
            self._val_as_param = val_as_param
            self._ready = True

        def as_sql(self, compiler, connection):
            """
            Generates SQL for the value previously passed to set_value()
            :param compiler: Django SQL compiler, passed through to Value.as_sql()
            :param connection: Database connection, passed through to Value.as_sql()
            :return: A tuple of SQL string and list of query parameters
            :raises ValueError: If set_value() has not been called before
            """
            if not self._ready:
                raise ValueError('BulkValue instance has not been initialized before using as_sql method')

            if not self._val_as_param:
                # The value is inserted into the SQL text as is, without query parameters.
                # NOTE(review): presumably the caller passes an already formatted SQL fragment here - confirm
                return str(self._value), []

            from django.db.models.expressions import Value
            return Value(self._value).as_sql(compiler, connection)

else:
    # Dummy class for django before 1.8
    class BulkValue:
        def as_sql(self, compiler, connection):
            raise NotImplementedError("This is not supported in your django version. Please, upgrade")


class AbstractSetFunction(AbstractFieldFormatter):
names = set()

Expand Down Expand Up @@ -189,34 +238,54 @@ def _get_field_column(self, field, with_table=False):


class DjangoSetFunction(AbstractSetFunction):
needs_value = False

def __init__(self, django_expression):  # type: (BaseExpression) -> None # noqa: F821
    """
    :param django_expression: Django expression used to calculate the field value
    :raises NotImplementedError: If django expressions are not available in the installed django version
    """
    if not django_expressions_available():
        # A plain string can not be raised: python would fail with an unrelated TypeError instead of this message
        raise NotImplementedError('Django expressions are available since django 1.8, please upgrade')

    self._django_expression = django_expression

    # Calculated lazily by needs_value property
    self._needs_value = None

@property
def needs_value(self):  # type: () -> bool
    """
    If expression contains BulkValue() references, value is required, otherwise - not
    The result is calculated on first access and cached in self._needs_value.
    :return: Boolean
    """
    if self._needs_value is None:
        # A list is used as a mutable flag, so the nested callback can report a match
        found = []

        def search_bulk_value_callback(expr):
            found.append(expr)
            return expr

        self._modify_expression_recursively(BulkValue, self._django_expression, search_bulk_value_callback)

        # Store an explicit False when nothing is found: leaving None here would return a non-boolean
        # and repeat the expression traversal on every access
        self._needs_value = bool(found)

    return self._needs_value

@classmethod
def _modify_column_refs_recursively(cls, expr, callback):
# type: (BaseExpression, Callable) -> BaseExpression # noqa: F821
def _modify_expression_recursively(cls, target_expression_class, expr, callback):
# type: (BaseExpression, BaseExpression, Callable) -> BaseExpression # noqa: F821
"""
Recursively iterates expression, searching for target_expression_class instances and calls callback for every found expression
:param target_expression_class: Expression class to search for
:param expr: Expression to process
:param callback: Function to apply to every expression. Should take exactly 1 argument: Col instance
:param callback: Function to apply to every expression. Should take exactly 1 argument:
target_expression_class instance
:return: Processed expression
"""
from django.db.models.expressions import Col

if isinstance(expr, Col):
if isinstance(expr, target_expression_class):
return callback(expr)

if not hasattr(expr, 'get_source_expressions'):
return expr

src_expressions = expr.get_source_expressions()
if not src_expressions:
return expr

expr = expr.copy()
new_src_expressions = [cls._modify_column_refs_recursively(sub_expr, callback) for sub_expr in src_expressions]
new_src_expressions = [
cls._modify_expression_recursively(target_expression_class, sub_expr, callback)
for sub_expr in src_expressions
]
expr.set_source_expressions(new_src_expressions)
return expr

Expand All @@ -237,10 +306,10 @@ def replace_with_default_values(col): # type: (Col) -> BaseExpression
default_value = NULL_DEFAULTS.get(col.field.__class__.__name__)
return Value(default_value)

return cls._modify_column_refs_recursively(expr, replace_with_default_values)
return cls._modify_expression_recursively(Col, expr, replace_with_default_values)

@classmethod
def remove_aliases_from_expression(cls, expr):
def remove_aliases_from_expression(cls, expr): # type: (BaseExpression) -> BaseExpression # noqa: F821
"""
Removes table alias for functions which reference columns, if with_table flag is False
In django 3.1+ This can be achieved by alias_cols=False Query flag.
Expand All @@ -260,7 +329,26 @@ def as_sql(compiler, connection):
col.as_sql = as_sql
return col

return cls._modify_column_refs_recursively(expr, remove_alias_from_col)
return cls._modify_expression_recursively(Col, expr, remove_alias_from_col)

@classmethod
def replace_bulk_value_in_expression(cls, expr, val, val_as_param):
    # type: (BaseExpression, Any, bool) -> BaseExpression # noqa: F821
    """
    Passes real data to every BulkValue instance found in expression
    Note: BulkValue instances are updated in place by set_value(), and the expression passed in is returned
    :param expr: Expression to process
    :param val: Value passed to get_sql_value method
    :param val_as_param: If flag is not set, value should be converted to string and inserted into query directly.
      Otherwise, a placeholder and query parameter will be used
    :return: Processed expression
    """
    def set_bulk_value_real_data(bulk_value):
        bulk_value.set_value(val, val_as_param)
        return bulk_value

    # Traversal result is not used: the callback has already updated BulkValue instances in place
    cls._modify_expression_recursively(BulkValue, expr, set_bulk_value_real_data)
    return expr

@classmethod
def get_query(cls, field, with_table=False, for_update=True): # type: (Field, bool, bool) -> Query
Expand All @@ -276,15 +364,18 @@ def get_query(cls, field, with_table=False, for_update=True): # type: (Field, b
return query

@classmethod
def resolve_expression(cls, field, expr, connection, with_table=False, for_update=True):
# type: (Field, Any, TDatabase, bool, bool) -> Tuple[SQLCompiler, BaseExpression] # noqa: F821
def resolve_expression(cls, field, expr, connection, val, val_as_param=False, with_table=False, for_update=True):
# type: (Field, Any, TDatabase, Any, bool, bool, bool) -> Tuple[SQLCompiler, BaseExpression] # noqa: F821
"""
Processes django expression, preparing it for SQL Generation
Note: expression resolve has been mostly copied from SQLUpdateCompiler.as_sql() method
and adopted for this function purposes
:param field: Django field expression will be applied to
:param expr: Expression to process
:param connection: Connection used to update data
:param val: Value passed to get_sql_value method
:param val_as_param: If flag is not set, value should be converted to string and inserted into query directly.
Otherwise, a placeholder and query parameter will be used
:param with_table: If flag is set, column name in sql is prefixed by table name
:param for_update: If flag is set, returns update sql. Otherwise - insert SQL
:return: A tuple of compiler used to format expression and result expression
Expand All @@ -293,6 +384,7 @@ def resolve_expression(cls, field, expr, connection, with_table=False, for_updat
compiler = query.get_compiler(connection=connection)

compiler.pre_sql_setup()
expr = cls.replace_bulk_value_in_expression(expr, val, val_as_param)
expr = expr.resolve_expression(query=query, allow_joins=False, for_save=True)
if expr.contains_aggregate:
raise FieldError(
Expand Down Expand Up @@ -321,12 +413,13 @@ def modify_create_params(self, model, key, kwargs, connection):

# Django is sets its built in defaults here. Let's replace column aliases with this library defaults
field = model._meta.get_field(key)
_, expr = self.resolve_expression(field, self._django_expression, connection, for_update=False)
_, expr = self.resolve_expression(field, self._django_expression, connection, kwargs.get(key), for_update=False)
kwargs[key] = expr
return kwargs

def get_sql_value(self, field, val, connection, val_as_param=True, with_table=False, for_update=True, **kwargs):
compiler, expr = self.resolve_expression(field, self._django_expression, connection, with_table=with_table,
compiler, expr = self.resolve_expression(field, self._django_expression, connection, val,
val_as_param=val_as_param, with_table=with_table,
for_update=for_update)

# SQL forming is copied from SQLUpdateCompiler.as_sql() and adopted for this function purposes
Expand Down
7 changes: 4 additions & 3 deletions tests/test_bulk_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from django_pg_bulk_update.compatibility import jsonb_available, hstore_available, array_available, tz_utc, \
django_expressions_available
from django_pg_bulk_update.query import bulk_create
from django_pg_bulk_update.set_functions import ConcatSetFunction
from django_pg_bulk_update.set_functions import ConcatSetFunction, BulkValue
from tests.models import TestModel, UpperCaseModel, AutoNowModel, TestModelWithSchema, UUIDFieldPrimaryModel


Expand Down Expand Up @@ -481,11 +481,12 @@ def test_auto_now_respects_override(self):
def test_django_expression(self):
# Default for IntegerField should be 0
from django.db.models import F
res = bulk_create(TestModel, [{'id': 11}, {'id': 12}], set_functions={'int_field': F('int_field') + 1})
res = bulk_create(TestModel, [{'id': 11, 'int_field': 1}, {'id': 12, 'int_field': 2}],
set_functions={'int_field': F('int_field') + BulkValue()})

self.assertEqual(2, res)
for instance in TestModel.objects.filter(pk__in={11, 12}):
self.assertEqual(1, instance.int_field)
self.assertEqual(instance.pk - 10, instance.int_field)
self.assertEqual('', instance.name)


Expand Down
8 changes: 4 additions & 4 deletions tests/test_bulk_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from django_pg_bulk_update.compatibility import jsonb_available, hstore_available, array_available, tz_utc, \
django_expressions_available
from django_pg_bulk_update.query import bulk_update
from django_pg_bulk_update.set_functions import ConcatSetFunction
from django_pg_bulk_update.set_functions import ConcatSetFunction, BulkValue
from tests.models import TestModel, RelationModel, UpperCaseModel, AutoNowModel, TestModelWithSchema, \
UUIDFieldPrimaryModel

Expand Down Expand Up @@ -697,12 +697,12 @@ def test_django_expression(self):
from django.db.models import F
from django.db.models.functions import Upper

res = bulk_update(TestModel, [{'id': 1}, {'id': 2}],
set_functions={'name': Upper('name'), 'int_field': F('int_field') + 1})
res = bulk_update(TestModel, [{'id': 1, 'int_field': 2}, {'id': 2, 'int_field': 4}],
set_functions={'name': Upper('name'), 'int_field': F('int_field') + BulkValue()})

self.assertEqual(2, res)
for instance in TestModel.objects.filter(pk__in={1, 2}):
self.assertEqual(instance.pk + 1, instance.int_field)
self.assertEqual(3 * instance.pk, instance.int_field)
self.assertEqual('TEST%d' % instance.pk, instance.name)


Expand Down
40 changes: 35 additions & 5 deletions tests/test_bulk_update_or_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from django_pg_bulk_update.compatibility import jsonb_available, array_available, hstore_available, tz_utc, \
django_expressions_available
from django_pg_bulk_update.query import bulk_update_or_create
from django_pg_bulk_update.set_functions import ConcatSetFunction
from django_pg_bulk_update.set_functions import ConcatSetFunction, BulkValue
from django_pg_returning import ReturningQuerySet
from tests.compatibility import get_auto_now_date
from tests.models import TestModel, UniqueNotPrimary, UpperCaseModel, AutoNowModel, TestModelWithSchema, \
Expand Down Expand Up @@ -654,6 +654,33 @@ def test_example(self):
{"id": 4, "name": "concat2", "int_field": 1},
], list(TestModel.objects.all().order_by("id").values("id", "name", "int_field")))

@skipIf(not django_expressions_available(), "Django expressions are not supported")
def test_example_with_bulk_value(self):
    # bulk_create and bulk_update parts of the example are covered by other tests,
    # so the table is filled with the state the example has right before bulk_update_or_create
    initial_instances = [
        TestModel(pk=1, name="updated1", int_field=2),
        TestModel(pk=2, name="updated2", int_field=3),
        TestModel(pk=3, name="incr", int_field=4),
    ]
    TestModel.objects.bulk_create(initial_instances)

    # id 3 exists and is updated, id 4 is missing and is created
    values = [
        {"id": 3, "name": "_concat1", "int_field": 3},
        {"id": 4, "name": "concat2", "int_field": 4},
    ]
    set_functions = {'name': '||', 'int_field': F('int_field') + BulkValue()}

    res = bulk_update_or_create(TestModel, values, set_functions=set_functions)
    self.assertEqual(2, res)

    expected = [
        {"id": 1, "name": "updated1", "int_field": 2},
        {"id": 2, "name": "updated2", "int_field": 3},
        {"id": 3, "name": "incr_concat1", "int_field": 7},
        {"id": 4, "name": "concat2", "int_field": 4},
    ]
    actual = list(TestModel.objects.all().order_by("id").values("id", "name", "int_field"))
    self.assertListEqual(expected, actual)


class TestSetFunctions(TestCase):
fixtures = ['test_model']
Expand Down Expand Up @@ -862,19 +889,22 @@ def test_django_expression(self):

res = bulk_update_or_create(TestModel, [{
'id': 1,
'int_field': 1
}, {
'id': 5,
'int_field': 5
}, {
'id': 11
}], set_functions={'int_field': F('int_field') + 1, 'name': Upper('name')})
'id': 11,
'int_field': 11
}], set_functions={'int_field': F('int_field') + BulkValue(), 'name': Upper('name')})

self.assertEqual(3, res)
for pk, name, int_field in TestModel.objects.all().order_by('id').values_list('id', 'name', 'int_field'):
if pk in {1, 5}:
self.assertEqual(pk + 1, int_field)
self.assertEqual(pk * 2, int_field)
self.assertEqual('TEST%d' % pk, name)
elif pk > 10:
self.assertEqual(1, int_field)
self.assertEqual(pk, int_field)
self.assertEqual('', name)
else:
self.assertEqual(pk, int_field)
Expand Down

0 comments on commit 84b7c47

Please sign in to comment.