Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix error when adding custom fields generators to bakery via settings file #58

Merged
merged 10 commits into from
Jan 22, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
### Added

### Changed
<<<<<<< HEAD
- Improve code comments [PR #31](https://github.com/model-bakers/model_bakery/pull/31)
- Switch to tox-travis [PR #43](https://github.com/model-bakers/model_bakery/pull/43)
- Add black job [PR #42](https://github.com/model-bakers/model_bakery/pull/42)
- README.md instead of rst [PR #44](https://github.com/model-bakers/model_bakery/pull/44)
- Add Django 3.0 and Python 3.8 to CI [PR #48](https://github.com/model-bakers/model_bakery/pull/48/)
- Add `start` argument to `baker.seq` [PR #56](https://github.com/model-bakers/model_bakery/pull/56)
- Fixes bug when registering custom fields generator via `settings.py` [PR #58](https://github.com/model-bakers/model_bakery/pull/58)

### Removed

Expand Down
15 changes: 9 additions & 6 deletions model_bakery/baker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from os.path import dirname, join

from django.conf import settings
from django.contrib.contenttypes.models import ContentType
from django.contrib import contenttypes
from django.apps import apps
from django.contrib.contenttypes.fields import GenericRelation

from django.db.models.base import ModelBase
from django.db.models import (
Expand Down Expand Up @@ -403,6 +402,8 @@ def is_rel_field(x):
]

def _skip_field(self, field):
from django.contrib.contenttypes.fields import GenericRelation

# check for fill optional argument
if isinstance(self.fill_in_optional, bool):
field.fill_optional = self.fill_in_optional
Expand Down Expand Up @@ -488,14 +489,16 @@ def generate_value(self, field, commit=True):
`attr_mapping` and `type_mapping` can be defined easily overwriting the
model.
"""
is_content_type_fk = isinstance(field, ForeignKey) and issubclass(
self._remote_field(field).model, contenttypes.models.ContentType
)

if field.name in self.attr_mapping:
generator = self.attr_mapping[field.name]
elif getattr(field, "choices"):
generator = random_gen.gen_from_choices(field.choices)
elif isinstance(field, ForeignKey) and issubclass(
self._remote_field(field).model, ContentType
):
generator = self.type_mapping[ContentType]
elif is_content_type_fk:
generator = self.type_mapping[contenttypes.models.ContentType]
elif generators.get(field.__class__):
generator = generators.get(field.__class__)
elif field.__class__ in self.type_mapping:
Expand Down
10 changes: 6 additions & 4 deletions model_bakery/generators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from django.contrib.contenttypes.models import ContentType
from django.db.models import (
BigIntegerField,
BinaryField,
Expand Down Expand Up @@ -30,7 +29,6 @@
)

from . import random_gen
from .gis import default_gis_mapping
from .utils import import_from_str

try:
Expand Down Expand Up @@ -88,7 +86,6 @@
FileField: random_gen.gen_file_field,
ImageField: random_gen.gen_image_field,
DurationField: random_gen.gen_interval,
ContentType: random_gen.gen_content_type,
}

if ArrayField:
Expand All @@ -105,11 +102,16 @@
default_mapping[CITextField] = random_gen.gen_text

# Add GIS fields
default_mapping.update(default_gis_mapping)


def get_type_mapping():
from django.contrib.contenttypes.models import ContentType
from .gis import default_gis_mapping

mapping = default_mapping.copy()
mapping[ContentType] = random_gen.gen_content_type
default_mapping.update(default_gis_mapping)

return mapping.copy()


Expand Down
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def pytest_configure():
] + installed_apps
else:
raise NotImplementedError("Tests for % are not supported", test_db)

settings.configure(
DATABASES={"default": {"ENGINE": db_engine, "NAME": db_name}},
INSTALLED_APPS=installed_apps,
Expand All @@ -38,4 +39,12 @@ def pytest_configure():
MIDDLEWARE=(),
USE_TZ=os.environ.get("USE_TZ", False),
)

from model_bakery import baker

def gen_same_text():
return "always the same text"

baker.generators.add("tests.generic.fields.CustomFieldViaSettings", gen_same_text)

django.setup()
4 changes: 4 additions & 0 deletions tests/generic/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ class CustomFieldWithoutGenerator(models.TextField):
pass


class CustomFieldViaSettings(models.TextField):
pass


class FakeListField(models.TextField):
def to_python(self, value):
return value.split()
Expand Down
5 changes: 5 additions & 0 deletions tests/generic/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .fields import (
CustomFieldWithGenerator,
CustomFieldWithoutGenerator,
CustomFieldViaSettings,
FakeListField,
CustomForeignKey,
)
Expand Down Expand Up @@ -292,6 +293,10 @@ class CustomFieldWithoutGeneratorModel(models.Model):
custom_value = CustomFieldWithoutGenerator()


class CustomFieldViaSettingsModel(models.Model):
custom_value = CustomFieldViaSettings()


class CustomForeignKeyWithGeneratorModel(models.Model):
custom_fk = CustomForeignKey(
Profile, blank=True, null=True, on_delete=models.CASCADE
Expand Down
4 changes: 4 additions & 0 deletions tests/test_filling_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,10 @@ def gen_char():

assert "Some value" == person.name

def test_ensure_adding_generators_via_settings_works(self):
obj = baker.make(models.CustomFieldViaSettingsModel)
assert "always the same text" == obj.custom_value


@pytest.mark.django_db
class TestFillingAutoFields:
Expand Down