Skip to content

Commit

Permalink
Merge pull request #230 from bbguimaraes/types
Browse files Browse the repository at this point in the history
revisions: fix (some) type annotations
  • Loading branch information
jsvgoncalves authored Jan 15, 2025
2 parents baeb7c8 + bcc07a3 commit 62c0506
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 19 deletions.
6 changes: 3 additions & 3 deletions bw2data/backends/proxies.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import uuid
from collections.abc import Iterable
from typing import Callable, List, Optional
from typing import Callable, List, Optional, TypeAlias

import pandas as pd

Expand Down Expand Up @@ -186,7 +186,7 @@ def to_dataframe(


class Activity(ActivityProxyBase):
ORMDataset = ActivityDataset
ORMDataset: TypeAlias = ActivityDataset

def __init__(self, document=None, **kwargs):
"""Create an `Activity` proxy object.
Expand Down Expand Up @@ -556,7 +556,7 @@ def copy(self, code: Optional[str] = None, signal: bool = True, **kwargs):


class Exchange(ExchangeProxyBase):
ORMDataset = ExchangeDataset
ORMDataset: TypeAlias = ExchangeDataset

def __init__(self, document=None, **kwargs):
"""Create an `Exchange` proxy object.
Expand Down
19 changes: 19 additions & 0 deletions bw2data/data_store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pickle
from abc import abstractmethod

from bw_processing import (
clean_datapackage_name,
Expand Down Expand Up @@ -148,6 +149,24 @@ class ProcessedDataStore(DataStore):

matrix = "unknown"

@abstractmethod
def make_searchable(self, reset: bool = False, signal: bool = True) -> bool:
pass

@abstractmethod
def make_unsearchable(self, signal: bool = False) -> bool:
pass

@abstractmethod
def delete(
self,
keep_params: bool = False,
warn: bool = True,
vacuum: bool = True,
signal: bool = True,
):
pass

def dirpath_processed(self):
return projects.dir / "processed"

Expand Down
54 changes: 38 additions & 16 deletions bw2data/revisions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
import itertools
import json
from typing import TYPE_CHECKING, Any, Iterable, Iterator, Optional, Sequence, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
Optional,
Sequence,
TypeAlias,
TypeVar,
Union,
)

import deepdiff

Expand Down Expand Up @@ -29,6 +40,7 @@

if TYPE_CHECKING:
import typing
import peewee


T = TypeVar("T")
Expand Down Expand Up @@ -250,8 +262,11 @@ def database_metadata_change(cls, old: dict, new: dict) -> Union[Self, None]:

@classmethod
def generate(
cls, old: Optional[Any], new: Optional[Any], operation: Optional[str] = None
) -> Self:
cls,
old: Optional[SignaledDataset],
new: Optional[SignaledDataset],
operation: Optional[str] = None,
) -> Optional[Self]:
"""
Generates a patch object from one version of an object to another.
Expand All @@ -266,10 +281,15 @@ def generate(
if operation is not None:
return getattr(cls, operation)(old, new)

if old is None and new is None:
if old is not None:
obj_id = old.id
obj_type = old.__class__
elif new is not None:
obj_id = new.id
obj_type = new.__class__
else:
raise ValueError("Both `new` and `old` are `None`")

obj_type = new.__class__ if new is not None else old.__class__
if old is not None and new is not None:
if old.__class__ != new.__class__:
raise IncompatibleClasses(f"Can't diff {old.__class__} and {new.__class__}")
Expand All @@ -294,12 +314,7 @@ def generate(
if not diff:
return None

return cls.from_difference(
label,
old.id if old is not None else new.id,
change_type,
diff,
)
return cls.from_difference(label, obj_id, change_type, diff)


class JSONEncoder(json.JSONEncoder):
Expand Down Expand Up @@ -345,6 +360,10 @@ class RevisionedORMProxy:
`Node` (and similar for edges).
"""

ORM_CLASS: type["peewee.Model"]
PROXY_CLASS: type["peewee.Model"]
orm_as_dict: Callable[["peewee.Model"], dict]

@classmethod
def handle(cls, revision_data: dict) -> None:
getattr(cls, revision_data["change_type"])(revision_data)
Expand Down Expand Up @@ -386,6 +405,8 @@ def create(cls, revision_data: dict) -> None:


class RevisionedParameter(RevisionedORMProxy):
KEYS: Sequence[str]

@classmethod
def _state_as_dict(cls, obj: ParameterBase) -> dict:
return {key: getattr(obj, key) for key in cls.KEYS}
Expand Down Expand Up @@ -512,10 +533,10 @@ def activity_parameter_update_formula_activity_parameter_name(cls, revision_data

class RevisionedNode(RevisionedORMProxy):
PROXY_CLASS = Activity
ORM_CLASS = Activity.ORMDataset
ORM_CLASS: TypeAlias = Activity.ORMDataset

@classmethod
def orm_as_dict(cls, orm_object: Activity.ORMDataset) -> dict:
def orm_as_dict(cls, orm_object: ORM_CLASS) -> dict:
return orm_object.data

@classmethod
Expand Down Expand Up @@ -543,10 +564,10 @@ def activity_code_change(cls, revision_data: dict) -> None:

class RevisionedEdge(RevisionedORMProxy):
PROXY_CLASS = Exchange
ORM_CLASS = Exchange.ORMDataset
ORM_CLASS: TypeAlias = Exchange.ORMDataset

@classmethod
def orm_as_dict(cls, orm_object: Exchange.ORMDataset) -> dict:
def orm_as_dict(cls, orm_object: ORM_CLASS) -> dict:
return dict_as_exchangedataset(orm_object.data)


Expand Down Expand Up @@ -578,9 +599,10 @@ def handle(cls, revision_data: dict) -> None:
ParameterizedExchange: "parameterized_exchange",
Group: "group",
}
REVISIONED_LABEL_AS_OBJECT = {
REVISIONED_LABEL_AS_OBJECT: dict[str, type[RevisionedORMProxy]] = {
"lci_node": RevisionedNode,
"lci_edge": RevisionedEdge,
# TODO separate uses
"lci_database": RevisionedDatabase,
"project_parameter": RevisionedProjectParameter,
"database_parameter": RevisionedDatabaseParameter,
Expand Down

0 comments on commit 62c0506

Please sign in to comment.