Fix ORM class inheritance hierarchy
cmutel authored and jsvgoncalves committed Nov 8, 2024
1 parent 977f78d commit 37cda76
Showing 15 changed files with 106 additions and 92 deletions.
4 changes: 3 additions & 1 deletion bw2data/backends/base.py
@@ -856,7 +856,9 @@ def _add_inventory_geomapping_to_datapackage(self, dp: Datapackage) -> None:
dict_iterator=(
{
"row": row[0],
"col": geomapping[location_mapper(retupleize_geo_strings(row[1]) or config.global_location)],
"col": geomapping[
location_mapper(retupleize_geo_strings(row[1]) or config.global_location)
],
"amount": 1,
}
for row in inv_mapping_qs.tuples()
4 changes: 2 additions & 2 deletions bw2data/backends/proxies.py
@@ -524,10 +524,10 @@ def copy(self, code: Optional[str] = None, signal: bool = True, **kwargs):

for exc in self.exchanges():
data = copy.deepcopy(exc._data)
if 'id' in data:
if "id" in data:
# New snowflake ID will be inserted by `.save()`; shouldn't be copied over
# or specified manually
del data['id']
del data["id"]
data["output"] = activity.key
# Change `input` for production exchanges
if exc["input"] == exc["output"]:
20 changes: 2 additions & 18 deletions bw2data/backends/schema.py
@@ -1,24 +1,8 @@
from peewee import DoesNotExist, Model, TextField, IntegerField
from peewee import DoesNotExist, TextField

from bw2data.errors import UnknownObject
from bw2data.signals import SignaledDataset
from bw2data.snowflake_ids import SnowflakeIDBaseClass
from bw2data.sqlite import PickleField
from bw2data.snowflake_ids import snowflake_id_generator


class SnowflakeIDBaseClass(SignaledDataset):
id = IntegerField(primary_key=True)

def save(self, **kwargs):
if self.id is None:
# If the primary key column data is already present (even if the object doesn't exist in
# the database), peewee will make an `UPDATE` query. This will have no effect if there
            # isn't a matching row. Need to force an `INSERT` query instead as we generate the ids
# ourselves.
# https://docs.peewee-orm.com/en/latest/peewee/models.html#id4
self.id = next(snowflake_id_generator)
kwargs['force_insert'] = True
super().save(**kwargs)


class ActivityDataset(SnowflakeIDBaseClass):
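
With the Snowflake base class pulled out of `schema.py` (its new home in `bw2data/snowflake_ids.py` appears further down), the resulting ORM hierarchy can be summarized with the editorial sketch below. It is not part of the commit; `ExchangeDataset` is not visible in the truncated hunk above and is assumed to follow the same pattern as `ActivityDataset`, and `SignaledDataset` is assumed to be the signal-emitting peewee model imported from `bw2data.signals`.

# Editorial sketch of the inheritance chain after this commit:
#   SignaledDataset -> SnowflakeIDBaseClass -> ActivityDataset / ExchangeDataset
from bw2data.backends.schema import ActivityDataset, ExchangeDataset
from bw2data.signals import SignaledDataset
from bw2data.snowflake_ids import SnowflakeIDBaseClass

assert issubclass(SnowflakeIDBaseClass, SignaledDataset)
assert issubclass(ActivityDataset, SnowflakeIDBaseClass)
assert issubclass(ExchangeDataset, SnowflakeIDBaseClass)  # assumed; not shown in the hunk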
8 changes: 4 additions & 4 deletions bw2data/backends/utils.py
@@ -1,15 +1,15 @@
from typing import Any
import copy
import warnings
from typing import Optional
from typing import Any, Optional

import numpy as np

from bw2data import config
from bw2data.backends.schema import SignaledDataset, get_id
from bw2data.backends.schema import get_id
from bw2data.configuration import labels
from bw2data.errors import InvalidExchange, UntypedExchange
from bw2data.meta import databases, methods
from bw2data.signals import SignaledDataset
from bw2data.snowflake_ids import snowflake_id_generator


@@ -81,7 +81,7 @@ def dict_as_activitydataset(ds: Any, add_snowflake_id: bool = False) -> dict:
# Use during `insert_many` calls as these skip auto id generation because they don't call
# `.save()`
if add_snowflake_id:
val['id'] = next(snowflake_id_generator)
val["id"] = next(snowflake_id_generator)
return val


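
Because `insert_many` bulk inserts rows without calling `.save()`, it never reaches the id generation in `SnowflakeIDBaseClass.save()`; the new `add_snowflake_id` flag lets callers set ids explicitly instead. A minimal sketch of that pattern follows, assuming an iterable of activity dicts named `datasets`; the actual call sites are not part of this diff.

from bw2data.backends.schema import ActivityDataset
from bw2data.backends.utils import dict_as_activitydataset

# `datasets` is an assumed iterable of activity dicts; each row gets an explicit
# snowflake id here because `insert_many` never calls `.save()`.
rows = [dict_as_activitydataset(ds, add_snowflake_id=True) for ds in datasets]
ActivityDataset.insert_many(rows).execute()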
4 changes: 1 addition & 3 deletions bw2data/project.py
@@ -7,7 +7,7 @@
from copy import copy
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Sequence, TypeVar
from typing import TYPE_CHECKING, Any, Optional, Sequence

import wrapt
from bw_processing import safe_filename
@@ -25,8 +25,6 @@

if TYPE_CHECKING:
from bw2data import revisions
from bw2data.backends import schema
SD = TypeVar("SD", bound=schema.SignaledDataset)


READ_ONLY_PROJECT = """
8 changes: 3 additions & 5 deletions bw2data/revisions.py
@@ -3,11 +3,12 @@

import deepdiff

from bw2data.snowflake_ids import snowflake_id_generator
from bw2data.backends.proxies import Activity, Exchange
from bw2data.backends.schema import ActivityDataset, ExchangeDataset, SignaledDataset
from bw2data.backends.schema import ActivityDataset, ExchangeDataset
from bw2data.backends.utils import dict_as_activitydataset, dict_as_exchangedataset
from bw2data.errors import DifferentObjects, IncompatibleClasses, InconsistentData
from bw2data.signals import SignaledDataset
from bw2data.snowflake_ids import snowflake_id_generator
from bw2data.utils import get_node

try:
@@ -16,9 +17,6 @@
from typing_extensions import Self


SD = TypeVar("SD", bound=SignaledDataset)


class RevisionGraph:
"""Graph of revisions, edges are based on `metadata.parent_revision`."""

21 changes: 20 additions & 1 deletion bw2data/snowflake_ids.py
@@ -1,6 +1,10 @@
from snowflake import SnowflakeGenerator
import uuid

from peewee import IntegerField
from snowflake import SnowflakeGenerator

from bw2data.signals import SignaledDataset

# Jan 1, 2024
# from datetime import datetime
# (datetime(2024, 1, 1) - datetime.utcfromtimestamp(0)).total_seconds() * 1000.0
@@ -18,3 +22,18 @@
# Sequence number: 12 bits. For every ID generated on that machine/process, the sequence number is
# incremented by 1. The number is reset to 0 every millisecond.
snowflake_id_generator = SnowflakeGenerator(instance=uuid.getnode() % 1024, epoch=EPOCH_START_MS)


class SnowflakeIDBaseClass(SignaledDataset):
id = IntegerField(primary_key=True)

def save(self, **kwargs):
if self.id is None:
# If the primary key column data is already present (even if the object doesn't exist in
# the database), peewee will make an `UPDATE` query. This will have no effect if there
            # isn't a matching row. Need to force an `INSERT` query instead as we generate the ids
# ourselves.
# https://docs.peewee-orm.com/en/latest/peewee/models.html#id4
self.id = next(snowflake_id_generator)
kwargs["force_insert"] = True
super().save(**kwargs)
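
For orientation, a short editorial sketch of how these ids are produced and decomposed. It is not part of the commit and assumes the conventional 41-bit timestamp / 10-bit instance / 12-bit sequence layout described in the comments above.

from bw2data.snowflake_ids import EPOCH_START_MS, snowflake_id_generator

new_id = next(snowflake_id_generator)  # 64-bit integer: timestamp | instance | sequence

# Shifting off the 10 instance bits and 12 sequence bits leaves the millisecond
# timestamp relative to the custom epoch (Jan 1, 2024).
unix_ms = (new_id >> 22) + EPOCH_START_MS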
13 changes: 6 additions & 7 deletions tests/compatibility.py
@@ -9,7 +9,6 @@
databases,
geomapping,
get_activity,
get_node,
get_multilca_data_objs,
get_node,
methods,
@@ -69,10 +68,10 @@ def test_prepare_lca_inputs_basic(setup):
assert list(d.values()) == [1]
assert {o.metadata["id"] for o in objs} == {o.datapackage().metadata["id"] for o in setup}

b1 = get_node(database="biosphere", code='1').id
b2 = get_node(database="biosphere", code='2').id
f1 = get_node(database="food", code='1').id
f2 = get_node(database="food", code='2').id
b1 = get_node(database="biosphere", code="1").id
b2 = get_node(database="biosphere", code="2").id
f1 = get_node(database="food", code="1").id
f2 = get_node(database="food", code="2").id

remapping_expected = {
"activity": {
@@ -118,8 +117,8 @@ def test_prepare_lca_inputs_multiple_demands(setup):
d, objs, r = prepare_lca_inputs(
demands=[{("food", "1"): 1}, {("food", "2"): 10}], method=("foo",)
)
f1 = get_node(database="food", code='1').id
f2 = get_node(database="food", code='2').id
f1 = get_node(database="food", code="1").id
f2 = get_node(database="food", code="2").id
assert d == [{f1: 1}, {f2: 10}]
assert {o.metadata["id"] for o in objs} == {o.datapackage().metadata["id"] for o in setup}

1 change: 1 addition & 0 deletions tests/conftest.py
@@ -1,4 +1,5 @@
"""Fixtures for bw2data"""

import sqlite3

# import pytest
12 changes: 6 additions & 6 deletions tests/database.py
@@ -11,15 +11,14 @@
calculation_setups,
databases,
geomapping,
projects,
get_activity,
get_id,
get_node,
projects,
)
from bw2data.backends import Activity as PWActivity
from bw2data.backends import sqlite3_lci_db
from bw2data.database import Database
from bw2data.snowflake_ids import EPOCH_START_MS
from bw2data.errors import (
DuplicateNode,
InvalidExchange,
@@ -33,6 +32,7 @@
ParameterizedExchange,
parameters,
)
from bw2data.snowflake_ids import EPOCH_START_MS
from bw2data.tests import bw2test

from .fixtures import biosphere
@@ -105,10 +105,10 @@ def test_copy(food):

def test_copy_metadata(food):
d = Database("food")
d.metadata['custom'] = "something"
d.metadata["custom"] = "something"
d.copy("repas")
assert "repas" in databases
assert databases['repas']['custom'] == 'something'
assert databases["repas"]["custom"] == "something"


@bw2test
@@ -458,7 +458,7 @@ def test_geomapping_array_includes_only_processes():
@bw2test
def test_geomapping_array_normalization():
database = Database("a database")
database.register(location_normalization={'RoW': 'GLO'})
database.register(location_normalization={"RoW": "GLO"})
database.write(
{
("a database", "foo"): {
@@ -717,7 +717,7 @@ def test_process_without_exchanges_still_in_processed_array():

package = database.datapackage()
array = package.get_resource("a_database_technosphere_matrix.data")[0]
assert array[0] == 1.
assert array[0] == 1.0
assert array.shape == (1,)


2 changes: 1 addition & 1 deletion tests/fixtures/__init__.py
@@ -1 +1 @@
from .basic import get_naughty, food2, food, biosphere, lcia
from .basic import biosphere, food, food2, get_naughty, lcia
19 changes: 10 additions & 9 deletions tests/test_schema_migrations.py
@@ -1,9 +1,10 @@
from pathlib import Path
from bw2data.project import add_sourced_columns, projects, ProjectDataset, config
from bw2data.tests import bw2test
import shutil
from pathlib import Path

from peewee import SqliteDatabase

from bw2data.project import ProjectDataset, add_sourced_columns, config, projects
from bw2data.tests import bw2test

original_projects_db = Path(__file__).parent / "fixtures" / "projects.db"

@@ -27,13 +28,13 @@ def test_add_sourced_columns(tmp_path):

columns = {o.name: o for o in db.get_columns("projectdataset")}
assert "is_sourced" in columns
assert columns['is_sourced'].data_type.upper() == 'INTEGER'
assert columns['is_sourced'].default == '0'
assert columns['is_sourced'].null is True
assert columns["is_sourced"].data_type.upper() == "INTEGER"
assert columns["is_sourced"].default == "0"
assert columns["is_sourced"].null is True
assert "revision" in columns
assert columns['revision'].data_type.upper() == 'INTEGER'
assert columns['revision'].default is None
assert columns['revision'].null is True
assert columns["revision"].data_type.upper() == "INTEGER"
assert columns["revision"].default is None
assert columns["revision"].null is True

db = SqliteDatabase(tmp_path / "projects.backup-is-sourced.db")
db.connect()