Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#65] Add _bulk_create parameter to make #134

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

### Added

- Add new `_bulk_create` parameter to `make` for using Django manager `bulk_create` with `_quantity` [PR #134](https://github.com/model-bakers/model_bakery/pull/134)

### Changed

- Type hinting fixed for Recipe "_model" parameter
Expand Down
11 changes: 10 additions & 1 deletion model_bakery/baker.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def make(
_refresh_after_create: bool = False,
_create_files: bool = False,
_using: str = "",
_bulk_create: bool = False,
**attrs: Any
):
"""Create a persisted instance from a given model its associated models.
Expand All @@ -68,7 +69,14 @@ def make(
if _valid_quantity(_quantity):
raise InvalidQuantityException

if _quantity:
if _quantity and _bulk_create:
return baker.model._base_manager.bulk_create(
[
baker.prepare(_save_kwargs=_save_kwargs, **attrs)
for _ in range(_quantity)
]
)
elif _quantity:
return [
baker.make(
_save_kwargs=_save_kwargs,
Expand All @@ -77,6 +85,7 @@ def make(
)
for _ in range(_quantity)
]

return baker.make(
_save_kwargs=_save_kwargs, _refresh_after_create=_refresh_after_create, **attrs
)
Expand Down
6 changes: 6 additions & 0 deletions tests/generic/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,3 +409,9 @@ class Meta(object):

class SubclassOfAbstract(AbstractModel):
height = models.IntegerField()


class NonStandardManager(models.Model):
name = models.CharField(max_length=30)

manager = models.Manager()
83 changes: 79 additions & 4 deletions tests/test_baker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
from django.conf import settings
from django.db import connection
from django.db.models import Manager
from django.db.models.signals import m2m_changed
from django.test import TestCase
Expand All @@ -29,6 +30,45 @@ def test_import_seq_from_baker():
pytest.fail("{} raised".format(ImportError.__name__))


class QueryCount:
"""
Keep track of db calls.

Example:
========

qc = QueryCount()

with qc.start_count():
MyModel.objects.get(pk=1)
MyModel.objects.create()

qc.count # 2

"""

def __init__(self):
self.count = 0

def __call__(self, execute, sql, params, many, context):
"""
`django.db.connection.execute_wrapper` callback

https://docs.djangoproject.com/en/3.1/topics/db/instrumentation/
"""
self.count += 1
execute(sql, params, many, context)

def start_count(self):
"""
Reset query count to 0 and return context manager for wrapping db
queries.
"""
self.count = 0

return connection.execute_wrapper(self)


class TestsModelFinder:
def test_unicode_regression(self):
obj = baker.prepare("generic.Person")
Expand Down Expand Up @@ -114,11 +154,46 @@ def test_multiple_inheritance_creation(self):
@pytest.mark.django_db
class TestsBakerRepeatedCreatesSimpleModel:
def test_make_should_create_objects_respecting_quantity_parameter(self):
baker.make(models.Person, _quantity=5)
assert models.Person.objects.count() == 5
queries = QueryCount()

people = baker.make(models.Person, _quantity=5, name="George Washington")
assert all(p.name == "George Washington" for p in people)
with queries.start_count():
baker.make(models.Person, _quantity=5)
assert queries.count == 5
assert models.Person.objects.count() == 5

with queries.start_count():
people = baker.make(models.Person, _quantity=5, name="George Washington")
assert all(p.name == "George Washington" for p in people)
assert queries.count == 5

def test_make_quantity_respecting_bulk_create_parameter(self):
queries = QueryCount()

with queries.start_count():
baker.make(models.Person, _quantity=5, _bulk_create=True)
assert queries.count == 1
assert models.Person.objects.count() == 5

with queries.start_count():
people = baker.make(
models.Person, name="George Washington", _quantity=5, _bulk_create=True
)
assert all(p.name == "George Washington" for p in people)
assert queries.count == 1

with queries.start_count():
baker.make(models.NonStandardManager, _quantity=3, _bulk_create=True)
assert queries.count == 1
assert getattr(models.NonStandardManager, "objects", None) is None
assert (
models.NonStandardManager._base_manager
== models.NonStandardManager.manager
)
assert (
models.NonStandardManager._default_manager
== models.NonStandardManager.manager
)
assert models.NonStandardManager.manager.count() == 3

def test_make_raises_correct_exception_if_invalid_quantity(self):
with pytest.raises(InvalidQuantityException):
Expand Down