Use Snowflake IDs for nodes and edges
Co-Authored-By: João Gonçalves <jsvgoncalves@gmail.com>
cmutel and jsvgoncalves committed Nov 8, 2024
1 parent 610f8af commit f160f82
Showing 10 changed files with 127 additions and 65 deletions.
9 changes: 6 additions & 3 deletions bw2data/backends/base.py
@@ -28,7 +28,7 @@
from bw2data.backends.utils import (
check_exchange,
dict_as_activitydataset,
-dict_as_exchangedataset,
+_dict_as_exchangedataset,
get_csv_data_dict,
retupleize_geo_strings,
)
@@ -525,7 +525,7 @@ def _efficient_write_dataset(

if "output" not in exchange:
exchange["output"] = (ds["database"], ds["code"])
-exchanges.append(dict_as_exchangedataset(exchange))
+exchanges.append(_dict_as_exchangedataset(exchange))

# Query gets passed as INSERT INTO x VALUES ('?', '?'...)
# SQLite3 has a limit of 999 variables,
@@ -542,7 +542,7 @@ def _efficient_write_dataset(
check_activity_type(ds.get("type"))
check_activity_keys(ds)

-activities.append(dict_as_activitydataset(ds))
+activities.append(dict_as_activitydataset(ds, add_snowflake_id=True))

if len(activities) > 125:
ActivityDataset.insert_many(activities).execute()
@@ -687,6 +687,9 @@ def new_node(self, code: str = None, **kwargs):
kwargs.pop("database")
obj["database"] = self.name

if "id" in kwargs:
raise ValueError(f"`id` must be created automatically, but `id={kwargs['id']}` given.")

if code is None:
obj["code"] = uuid.uuid4().hex
else:
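
With this guard, node IDs can no longer be chosen by callers; they always come from the Snowflake generator. A minimal sketch of the new behaviour, assuming a registered project and a hypothetical database named "example":

    import pytest
    from bw2data import Database

    db = Database("example")
    db.register()

    with pytest.raises(ValueError):
        db.new_node(code="foo", id=42)  # rejected: IDs are generated automatically

    node = db.new_node(code="foo")
    node.save()
    assert node.id > 0  # Snowflake ID assigned during `.save()`
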
22 changes: 16 additions & 6 deletions bw2data/backends/proxies.py
@@ -14,7 +14,7 @@
check_exchange_keys,
check_exchange_type,
)
-from bw2data.backends.utils import dict_as_activitydataset, dict_as_exchangedataset
+from bw2data.backends.utils import dict_as_activitydataset, _dict_as_exchangedataset
from bw2data.configuration import labels
from bw2data.errors import ValidityError
from bw2data.logs import stdout_feedback_logger
@@ -338,6 +338,8 @@ def save(self, signal: bool = True, data_already_set: bool = False, force_insert
check_activity_keys(self)

for key, value in dict_as_activitydataset(self._data).items():
+# ID value is either already in `._document` (update) or will be created by
+# `SnowflakeIDBaseClass.save()`.
if key != "id":
setattr(self._document, key, value)

@@ -495,8 +497,10 @@ def new_edge(self, **kwargs):
"""Create a new exchange linked to this activity"""
exc = Exchange()
exc.output = self.key
-for key in kwargs:
-    exc[key] = kwargs[key]
+for key, value in kwargs.items():
+    if key == "id":
+        raise ValueError(f"`id` must be created automatically, but `id={value}` given.")
+    exc[key] = value
return exc

def copy(self, code: Optional[str] = None, signal: bool = True, **kwargs):
@@ -511,13 +515,19 @@ def copy(self, code: Optional[str] = None, signal: bool = True, **kwargs):
for key, value in self.items():
if key != "id":
activity[key] = value
-for k, v in kwargs.items():
-    activity._data[k] = v
+for key, value in kwargs.items():
+    if key == "id":
+        raise ValueError(f"`id` must be created automatically, but `id={value}` given.")
+    activity._data[key] = value
activity._data["code"] = str(code or uuid.uuid4().hex)
activity.save(signal=signal)

for exc in self.exchanges():
data = copy.deepcopy(exc._data)
+if 'id' in data:
+    # New snowflake ID will be inserted by `.save()`; shouldn't be copied over
+    # or specified manually
+    del data['id']
data["output"] = activity.key
# Change `input` for production exchanges
if exc["input"] == exc["output"]:
@@ -564,7 +574,7 @@ def save(self, signal: bool = True, data_already_set: bool = False, force_insert
check_exchange_type(self._data.get("type"))
check_exchange_keys(self)

-for key, value in dict_as_exchangedataset(self._data).items():
+for key, value in _dict_as_exchangedataset(self._data).items():
setattr(self._document, key, value)

self._document.save(signal=signal, force_insert=force_insert)
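
Because `copy()` now drops `id` from the node and from every copied exchange, each copy receives a fresh Snowflake ID on save instead of colliding with the original. A short sketch, assuming `node` is an existing saved activity:

    duplicate = node.copy(code="node-copy")    # hypothetical new code value
    assert duplicate.id != node.id             # fresh Snowflake ID assigned in `.save()`
    for exc in duplicate.exchanges():
        assert exc["output"] == duplicate.key  # copied edges point at the copy
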
22 changes: 19 additions & 3 deletions bw2data/backends/schema.py
@@ -1,11 +1,27 @@
-from peewee import DoesNotExist, TextField
+from peewee import DoesNotExist, Model, TextField, IntegerField

from bw2data.errors import UnknownObject
from bw2data.signals import SignaledDataset
from bw2data.sqlite import PickleField
+from bw2data.snowflake_ids import snowflake_id_generator


-class ActivityDataset(SignaledDataset):
+class SnowflakeIDBaseClass(SignaledDataset):
+    id = IntegerField(primary_key=True)
+
+    def save(self, **kwargs):
+        if self.id is None:
+            # If the primary key column data is already present (even if the object doesn't
+            # exist in the database), peewee will make an `UPDATE` query. This will have no
+            # effect if there isn't a matching row. We need to force an `INSERT` query
+            # instead, as we generate the ids ourselves.
+            # https://docs.peewee-orm.com/en/latest/peewee/models.html#id4
+            self.id = next(snowflake_id_generator)
+            kwargs['force_insert'] = True
+        super().save(**kwargs)
+
+
+class ActivityDataset(SnowflakeIDBaseClass):
data = PickleField() # Canonical, except for other C fields
code = TextField() # Canonical
database = TextField() # Canonical
@@ -19,7 +35,7 @@ def key(self):
return (self.database, self.code)


-class ExchangeDataset(SignaledDataset):
+class ExchangeDataset(SnowflakeIDBaseClass):
data = PickleField() # Canonical, except for other C fields
input_code = TextField() # Canonical
input_database = TextField() # Canonical
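
The `force_insert` logic in `SnowflakeIDBaseClass.save()` exists because peewee chooses between `UPDATE` and `INSERT` purely by whether the primary key is already set. A standalone sketch of that peewee behaviour (a hypothetical model, not bw2data code):

    from peewee import IntegerField, Model, SqliteDatabase, TextField

    db = SqliteDatabase(":memory:")

    class Node(Model):
        id = IntegerField(primary_key=True)
        name = TextField()

        class Meta:
            database = db

    db.create_tables([Node])

    node = Node(id=12345, name="a")
    node.save()                      # UPDATE ... WHERE id = 12345: matches no row, writes nothing
    assert Node.select().count() == 0

    node.save(force_insert=True)     # INSERT: the row is actually written
    assert Node.select().count() == 1
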
13 changes: 10 additions & 3 deletions bw2data/backends/utils.py
@@ -1,3 +1,4 @@
+from typing import Any
import copy
import warnings
from typing import Optional
@@ -9,6 +10,7 @@
from bw2data.configuration import labels
from bw2data.errors import InvalidExchange, UntypedExchange
from bw2data.meta import databases, methods
+from bw2data.snowflake_ids import snowflake_id_generator


def get_csv_data_dict(ds):
@@ -66,8 +68,8 @@ def check_exchange(exc):
raise ValueError("Invalid amount in exchange {}".format(exc))


-def dict_as_activitydataset(ds):
-    return {
+def dict_as_activitydataset(ds: Any, add_snowflake_id: bool = False) -> dict:
+    val = {
"data": ds,
"database": ds["database"],
"code": ds["code"],
@@ -76,9 +78,14 @@ def dict_as_activitydataset(ds):
"product": ds.get("reference product"),
"type": ds.get("type", labels.process_node_default),
}
+# Used during `insert_many` calls, as these skip auto id generation because
+# they don't call `.save()`
+if add_snowflake_id:
+    val['id'] = next(snowflake_id_generator)
+return val


-def dict_as_exchangedataset(ds):
+def _dict_as_exchangedataset(ds: Any) -> dict:
return {
"data": ds,
"input_database": ds["input"][0],
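
`add_snowflake_id` exists because bulk writes via `insert_many` never call `Model.save()`, so `SnowflakeIDBaseClass` gets no chance to assign IDs. A sketch of the pattern `_efficient_write_dataset` uses, where `node_dicts` is a hypothetical iterable of node dictionaries:

    from bw2data.backends.schema import ActivityDataset
    from bw2data.backends.utils import dict_as_activitydataset

    # IDs must be filled in up front because `insert_many` bypasses `.save()`
    rows = [dict_as_activitydataset(ds, add_snowflake_id=True) for ds in node_dicts]
    ActivityDataset.insert_many(rows).execute()
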
20 changes: 20 additions & 0 deletions bw2data/snowflake_ids.py
@@ -0,0 +1,20 @@
+from snowflake import SnowflakeGenerator
+import uuid
+
+# Jan 1, 2024
+# from datetime import datetime
+# (datetime(2024, 1, 1) - datetime.utcfromtimestamp(0)).total_seconds() * 1000.0
+EPOCH_START_MS = 1704067200000
+
+# From https://softwaremind.com/blog/the-unique-features-of-snowflake-id-and-its-comparison-to-uuid/
+# Snowflake bits:
+# Sign bit: 1 bit. It will always be 0. This is reserved for future uses. It can potentially
+# be used to distinguish between signed and unsigned numbers.
+# Timestamp: 41 bits. Milliseconds since the epoch or custom epoch.
+# Datacenter ID: 5 bits, which gives us 2 ^ 5 = 32 datacenters.
+# Machine ID: 5 bits, which gives us 2 ^ 5 = 32 machines per datacenter.
+# However, `snowflake-id` lumps the datacenter and machine ID values together into a single
+# `instance` parameter with 2 ^ 10 = 1024 possible values.
+# Sequence number: 12 bits. For every ID generated on that machine/process, the sequence
+# number is incremented by 1. The number is reset to 0 every millisecond.
+snowflake_id_generator = SnowflakeGenerator(instance=uuid.getnode() % 1024, epoch=EPOCH_START_MS)
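
Given the layout above (41-bit timestamp, 10-bit instance, 12-bit sequence, packed high to low), a generated ID can be unpacked again. A sketch that assumes `snowflake-id` emits exactly that layout:

    from bw2data.snowflake_ids import EPOCH_START_MS, snowflake_id_generator

    def decode_snowflake(snowflake: int) -> dict:
        """Split a Snowflake ID back into its documented components."""
        return {
            "timestamp_ms": (snowflake >> 22) + EPOCH_START_MS,  # ms since the Unix epoch
            "instance": (snowflake >> 12) & 0x3FF,               # 10-bit machine/datacenter value
            "sequence": snowflake & 0xFFF,                       # 12-bit per-millisecond counter
        }

    print(decode_snowflake(next(snowflake_id_generator)))
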
7 changes: 7 additions & 0 deletions tests/activity_proxy.py
@@ -179,6 +179,9 @@ def test_copy(activity):
assert cp["name"] == "baz"
assert cp["location"] == "bar"
assert ExchangeDataset.select().count() == 2

+cp.save()

assert ActivityDataset.select().count() == 2
assert (
ActivityDataset.select()
@@ -241,6 +244,8 @@ def test_delete_activity_parameters():
b.save()
a.new_exchange(amount=0, input=b, type="technosphere", formula="foo * bar + 4").save()

+assert ExchangeDataset.select().count() == 1

activity_data = [
{
"name": "reference_me",
@@ -258,6 +263,8 @@
parameters.new_activity_parameters(activity_data, "my group")
parameters.add_exchanges_to_group("my group", a)

+assert ExchangeDataset.select().count() == 1

assert ActivityParameter.select().count() == 2
assert ParameterizedExchange.select().count() == 1

39 changes: 23 additions & 16 deletions tests/compatibility.py
@@ -9,6 +9,7 @@
databases,
geomapping,
get_activity,
+get_node,
get_multilca_data_objs,
get_node,
methods,
@@ -65,27 +66,32 @@ def setup():
def test_prepare_lca_inputs_basic(setup):
d, objs, r = prepare_lca_inputs(demand={("food", "1"): 1}, method=("foo",))
-# ID is 3; two biosphere flows, then '1' is next written
-assert d == {3: 1}
+assert list(d.values()) == [1]
assert {o.metadata["id"] for o in objs} == {o.datapackage().metadata["id"] for o in setup}

+b1 = get_node(database="biosphere", code='1').id
+b2 = get_node(database="biosphere", code='2').id
+f1 = get_node(database="food", code='1').id
+f2 = get_node(database="food", code='2').id

remapping_expected = {
"activity": {
-1: ("biosphere", "1"),
-2: ("biosphere", "2"),
-3: ("food", "1"),
-4: ("food", "2"),
+b1: ("biosphere", "1"),
+b2: ("biosphere", "2"),
+f1: ("food", "1"),
+f2: ("food", "2"),
},
"product": {
-1: ("biosphere", "1"),
-2: ("biosphere", "2"),
-3: ("food", "1"),
-4: ("food", "2"),
+b1: ("biosphere", "1"),
+b2: ("biosphere", "2"),
+f1: ("food", "1"),
+f2: ("food", "2"),
},
"biosphere": {
-1: ("biosphere", "1"),
-2: ("biosphere", "2"),
-3: ("food", "1"),
-4: ("food", "2"),
+b1: ("biosphere", "1"),
+b2: ("biosphere", "2"),
+f1: ("food", "1"),
+f2: ("food", "2"),
},
}
assert r == remapping_expected
@@ -104,16 +110,17 @@ def test_prepare_lca_inputs_multiple_demands_data_types(setup):
first = get_node(database="food", code="1")
second = get_node(database="food", code="2")
d, objs, r = prepare_lca_inputs(demands=[{first: 1}, {second.id: 10}], method=("foo",))
-assert d == [{3: 1}, {4: 10}]
+assert d == [{first.id: 1}, {second.id: 10}]
assert {o.metadata["id"] for o in objs} == {o.datapackage().metadata["id"] for o in setup}


def test_prepare_lca_inputs_multiple_demands(setup):
d, objs, r = prepare_lca_inputs(
demands=[{("food", "1"): 1}, {("food", "2"): 10}], method=("foo",)
)
-# ID is 3; two biosphere flows, then '1' is next written
-assert d == [{3: 1}, {4: 10}]
+f1 = get_node(database="food", code='1').id
+f2 = get_node(database="food", code='2').id
+assert d == [{f1: 1}, {f2: 10}]
assert {o.metadata["id"] for o in objs} == {o.datapackage().metadata["id"] for o in setup}


11 changes: 5 additions & 6 deletions tests/database.py
@@ -19,6 +19,7 @@
from bw2data.backends import Activity as PWActivity
from bw2data.backends import sqlite3_lci_db
from bw2data.database import Database
+from bw2data.snowflake_ids import EPOCH_START_MS
from bw2data.errors import (
DuplicateNode,
InvalidExchange,
@@ -63,7 +64,7 @@ def test_get_code():
activity = d.get("1")
assert isinstance(activity, PWActivity)
assert activity["name"] == "an emission"
-assert activity.id == 1
+assert activity.id > EPOCH_START_MS


@bw2test
@@ -73,7 +74,7 @@ def test_get_kwargs():
activity = d.get(name="an emission")
assert isinstance(activity, PWActivity)
assert activity["name"] == "an emission"
-assert activity.id == 1
+assert activity.id > EPOCH_START_MS


@bw2test
@@ -541,7 +542,7 @@ def test_processed_array_with_metadata():
"reference product": np.NaN,
"unit": "something",
"location": np.NaN,
"id": 1,
"id": df.id[0],
}
]
)
@@ -716,7 +717,7 @@ def test_process_without_exchanges_still_in_processed_array():

package = database.datapackage()
array = package.get_resource("a_database_technosphere_matrix.data")[0]
-assert array[0] == 1
+assert array[0] == 1.
assert array.shape == (1,)


@@ -774,8 +775,6 @@ def test_new_node_error():

with pytest.raises(DuplicateNode):
database.new_node("foo")
-with pytest.raises(DuplicateNode):
-    database.new_node(code="bar", id=act.id)


@bw2test
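
The `activity.id > EPOCH_START_MS` assertions hold because the 41-bit timestamp occupies the top bits: an ID generated `delta` milliseconds after the custom epoch is at least `delta << 22`, which exceeds `EPOCH_START_MS` (about 1.7e12) once `delta` passes roughly 406,000 ms, i.e. about seven minutes after Jan 1, 2024. A quick sanity check:

    from bw2data.snowflake_ids import EPOCH_START_MS, snowflake_id_generator

    # Any freshly generated ID is on the order of 1e17 in late 2024,
    # comfortably above EPOCH_START_MS (about 1.7e12).
    assert next(snowflake_id_generator) > EPOCH_START_MS
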
16 changes: 12 additions & 4 deletions tests/ia.py
@@ -278,8 +278,11 @@ def test_method_geocollection():
}
)

+f1 = get_node(code="1").id
+f2 = get_node(code="2").id

m = Method(("foo",))
-m.write([(1, 2, "RU"), (3, 4, ("foo", "bar"))])
+m.write([(f1, 2, "RU"), (f2, 4, ("foo", "bar"))])
assert m.metadata["geocollections"] == ["foo", "world"]


@@ -294,11 +297,14 @@ def test_method_geocollection_missing_ok():
}
)

+f1 = get_node(code="1").id
+f3 = get_node(code="3").id

m = Method(("foo",))
m.write(
[
-(1, 2, None),
-(3, 4),
+(f1, 2, None),
+(f3, 4),
]
)
assert m.metadata["geocollections"] == ["world"]
Expand All @@ -313,10 +319,12 @@ def test_method_geocollection_warning():
}
)

+f1 = get_node(code="1").id

m = Method(("foo",))
m.write(
[
(1, 2, "Russia"),
(f1, 2, "Russia"),
]
)
assert m.metadata["geocollections"] == []