Skip to content

Commit

Permalink
Store cached model defaults in self._defaults, avoid sharing referenc…
Browse files Browse the repository at this point in the history
…es to mutable defaults
  • Loading branch information
dmach committed Jan 3, 2024
1 parent 587c094 commit 16cdc06
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 43 deletions.
34 changes: 21 additions & 13 deletions osc/util/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
This module IS NOT a supported API, it is meant for osc internal use only.
"""

import copy
import inspect
import sys
import types
Expand Down Expand Up @@ -82,9 +83,6 @@ def __init__(
# a flag indicating, whether the default is a callable with lazy evalution
self.default_is_lazy = callable(self.default)

# whether the field was set
self.is_set = False

# the name of model's attribute associated with this field instance - set from the model
self.name = None

Expand Down Expand Up @@ -209,6 +207,11 @@ def get(self, obj):
except KeyError:
pass

try:
return obj._defaults[self.name]
except KeyError:
pass

if isinstance(self.default, FromParent):
if obj._parent is None:
raise RuntimeError(f"The field '{self.name}' has default {self.default} but the model has no parent set")
Expand All @@ -217,18 +220,23 @@ def get(self, obj):
if self.default is NotSet:
raise RuntimeError(f"The field '{self.name}' has no default")

# make a deepcopy to avoid problems with mutable defaults
default = copy.deepcopy(self.default)

# lazy evaluation of a factory function on first use
if callable(self.default):
self.default = self.default()
if callable(default):
default = default()

# if this is a model field, convert dict to a model instance
if self.is_model and isinstance(self.default, dict):
new_value = self.origin_type() # pylint: disable=not-callable
for k, v in self.default.items():
if self.is_model and isinstance(default, dict):
cls = self.origin_type
new_value = cls() # pylint: disable=not-callable
for k, v in default.items():
setattr(new_value, k, v)
self.default = new_value
default = new_value

return self.default
obj._defaults[self.name] = default
return default

def set(self, obj, value):
# if this is a model field, convert dict to a model instance
Expand All @@ -240,7 +248,6 @@ def set(self, obj, value):

self.validate_type(value)
obj._values[self.name] = value
self.is_set = True


class ModelMeta(type):
Expand Down Expand Up @@ -288,7 +295,8 @@ def __setattr__(self, name, value):
raise AttributeError(f"Setting attribute '{self.__class__.__name__}.{name}' is not allowed")

def __init__(self, **kwargs):
self._values = {}
self._defaults = {} # field defaults cached in field.get()
self._values = {} # field values explicitly set after initializing the model
self._parent = kwargs.pop("_parent", None)

uninitialized_fields = []
Expand Down Expand Up @@ -321,7 +329,7 @@ def dict(self, exclude_unset=False):
for name, field in self.__fields__.items():
if field.exclude:
continue
if exclude_unset and not field.is_set and field.is_optional:
if exclude_unset and field.name not in self._values and field.is_optional:
# include only mandatory fields and optional fields that were set to an actual value
continue
if field.is_model:
Expand Down
8 changes: 1 addition & 7 deletions tests/test_build.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import unittest

import osc.conf
Expand All @@ -8,12 +7,7 @@

class TestTrustedProjects(unittest.TestCase):
def setUp(self):
# reset the global `config` in preparation for running the tests
importlib.reload(osc.conf)

def tearDown(self):
# reset the global `config` to avoid impacting tests from other classes
importlib.reload(osc.conf)
osc.conf.config = osc.conf.Options()

def test_name(self):
apiurl = "https://example.com"
Expand Down
11 changes: 9 additions & 2 deletions tests/test_conf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import os
import shutil
import tempfile
Expand Down Expand Up @@ -105,7 +104,6 @@

class TestExampleConfig(unittest.TestCase):
def setUp(self):
importlib.reload(osc.conf)
self.tmpdir = tempfile.mkdtemp(prefix="osc_test_")
self.oscrc = os.path.join(self.tmpdir, "oscrc")
with open(self.oscrc, "w", encoding="utf-8") as f:
Expand Down Expand Up @@ -481,6 +479,15 @@ def test_write_initial_config(self):
}
osc.conf.write_initial_config(conffile, entries)

def test_api_host_options(self):
# test that instances do not share any references leaked from the defaults
conf1 = osc.conf.Options()
conf2 = osc.conf.Options()

self.assertNotEqual(conf1, conf2)
self.assertNotEqual(id(conf1), id(conf2))
self.assertNotEqual(id(conf1.api_host_options), id(conf2.api_host_options))


if __name__ == "__main__":
unittest.main()
5 changes: 0 additions & 5 deletions tests/test_grabber.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import os
import tempfile
import unittest
Expand All @@ -13,14 +12,10 @@
class TestMirrorGroup(unittest.TestCase):
def setUp(self):
self.tmpdir = tempfile.mkdtemp(prefix='osc_test')
# reset the global `config` in preparation for running the tests
importlib.reload(osc.conf)
oscrc = os.path.join(self._get_fixtures_dir(), "oscrc")
osc.conf.get_config(override_conffile=oscrc, override_no_keyring=True)

def tearDown(self):
# reset the global `config` to avoid impacting tests from other classes
importlib.reload(osc.conf)
try:
shutil.rmtree(self.tmpdir)
except:
Expand Down
9 changes: 4 additions & 5 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,12 @@ class TestModel(BaseModel):

m = TestModel()

field = m.__fields__["field"]
self.assertEqual(field.is_set, False)
self.assertNotIn("field", m._values)
self.assertEqual(m.field, None)

m.field = "text"
self.assertEqual(field.is_set, True)

self.assertIn("field", m._values)
self.assertEqual(m.field, "text")

def test_str(self):
Expand All @@ -95,7 +96,6 @@ class TestModel(BaseModel):
field = m.__fields__["field"]
self.assertEqual(field.is_model, False)
self.assertEqual(field.is_optional, False)
self.assertEqual(field.is_set, False)
self.assertEqual(field.origin_type, str)

self.assertEqual(m.field, "default")
Expand All @@ -111,7 +111,6 @@ class TestModel(BaseModel):
field = m.__fields__["field"]
self.assertEqual(field.is_model, False)
self.assertEqual(field.is_optional, True)
self.assertEqual(field.is_set, False)
self.assertEqual(field.origin_type, str)

self.assertEqual(m.field, None)
Expand Down
8 changes: 1 addition & 7 deletions tests/test_output.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import contextlib
import importlib
import io
import unittest

Expand Down Expand Up @@ -74,12 +73,7 @@ def test_wide_chars(self):

class TestPrintMsg(unittest.TestCase):
def setUp(self):
# reset the global `config` in preparation for running the tests
importlib.reload(osc.conf)

def tearDown(self):
# reset the global `config` to avoid impacting tests from other classes
importlib.reload(osc.conf)
osc.conf.config = osc.conf.Options()

def test_debug(self):
osc.conf.config["debug"] = False
Expand Down
5 changes: 1 addition & 4 deletions tests/test_vc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import os
import unittest

Expand All @@ -11,7 +10,7 @@

class TestVC(unittest.TestCase):
def setUp(self):
importlib.reload(osc.conf)
osc.conf.config = osc.conf.Options()

config = osc.conf.config
host_options = osc.conf.HostOptions(
Expand All @@ -21,8 +20,6 @@ def setUp(self):
config["apiurl"] = host_options["apiurl"]
self.host_options = host_options

def tearDown(self):
importlib.reload(osc.conf)

@patch.dict(os.environ, {}, clear=True)
def test_vc_export_env_conf(self):
Expand Down

0 comments on commit 16cdc06

Please sign in to comment.