Skip to content

Commit

Permalink
Changed PR to use auto loaders when possible
Browse files Browse the repository at this point in the history
  • Loading branch information
ramon committed Dec 8, 2021
1 parent 6b4e2a0 commit 4a62789
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 70 deletions.
92 changes: 32 additions & 60 deletions src/alfasim_sdk/_internal/alfacase/alfacase_to_case.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,14 @@
import enum
import inspect
from functools import lru_cache
from functools import partial
from functools import lru_cache, partial
from numbers import Number
from pathlib import Path
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Type
from typing import TypeVar
from typing import Union
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union

import attr
from attr.validators import instance_of
from barril.curve.curve import Curve
from barril.units import Array
from barril.units import Scalar
from barril.units import UnitDatabase
from barril.units import Array, Scalar, UnitDatabase
from strictyaml import YAML

from alfasim_sdk._internal import constants
Expand Down Expand Up @@ -56,9 +46,7 @@ def from_file(cls, file_path: Path) -> "DescriptionDocument":
"""
import strictyaml

from alfasim_sdk._internal.alfacase.case_description_attributes import (
DescriptionError,
)
from alfasim_sdk._internal.alfacase.case_description_attributes import DescriptionError
from alfasim_sdk._internal.alfacase.schema import case_description_schema

try:
Expand Down Expand Up @@ -97,7 +85,7 @@ def update_multi_input_flags(document: DescriptionDocument, item_description: T)
if (not is_set_in_alfacase) and isinstance(value, constants.MultiInputType):
assert key.endswith(constants.MULTI_INPUT_TYPE_SUFFIX)

constant_key = key[: -len(constants.MULTI_INPUT_TYPE_SUFFIX)]
constant_key = key[:-len(constants.MULTI_INPUT_TYPE_SUFFIX)]
has_constant_data = constant_key in document

curve_key = f"{constant_key}_curve"
Expand Down Expand Up @@ -137,8 +125,25 @@ def get_instance_loader(*, class_: type) -> Callable:
return partial(load_instance, class_=class_)


@lru_cache(maxsize=None)
def get_dict_of_loader(*, class_: type) -> Callable:
"""
Return a load instance function pre-populate with the class_. @ramon change the docstring
"""

def _get_dict_of_loader(alfacase_content: DescriptionDocument, class_: Type[T]) -> T:
return {
key.data: load_instance(
DescriptionDocument(value, alfacase_content.file_path), class_=class_
)
for key, value in alfacase_content.content.items()
}

return partial(_get_dict_of_loader, class_=class_)


def get_case_description_attribute_loader_dict(
class_: Any, explicit_loaders: Optional[Dict[str, Callable]] = None
class_: Any, explicit_loaders: Optional[Dict[str, Callable]]=None
) -> Dict[str, Callable]:
"""
Create a dict of loaders to be used with `to_case_values`.
Expand Down Expand Up @@ -186,7 +191,7 @@ def load_scalar(

@lru_cache(maxsize=None)
def get_scalar_loader(
*, category: Optional[str] = None, from_unit: Optional[str] = None
*, category: Optional[str]=None, from_unit: Optional[str]=None
) -> Callable:
"""
Return a LoadArray function pre-populate with the category
Expand All @@ -213,7 +218,7 @@ def load_array(key: str, alfacase_content: DescriptionDocument, category: str) -

@lru_cache(maxsize=None)
def get_array_loader(
*, category: Optional[str] = None, from_unit: Optional[str] = None
*, category: Optional[str]=None, from_unit: Optional[str]=None
) -> Callable:
"""
Return a LoadArray function pre-populate with the category
Expand Down Expand Up @@ -243,7 +248,7 @@ def load_list_of_arrays(

@lru_cache(maxsize=None)
def get_list_of_arrays_loader(
*, category: Optional[str] = None, from_unit: Optional[str] = None
*, category: Optional[str]=None, from_unit: Optional[str]=None
) -> Callable:
"""
Return a LoadListOfArrays function pre-populated with the category
Expand Down Expand Up @@ -273,7 +278,7 @@ def load_dict_of_arrays(

@lru_cache(maxsize=None)
def get_dict_of_arrays_loader(
*, category: Optional[str] = None, from_unit: Optional[str] = None
*, category: Optional[str]=None, from_unit: Optional[str]=None
) -> Callable:
"""
Return a LoadDictOfArrays function pre-populated with the category
Expand All @@ -300,7 +305,7 @@ def load_curve(key: str, alfacase_content: DescriptionDocument, category: str) -

@lru_cache(maxsize=None)
def get_curve_loader(
*, category: Optional[str] = None, from_unit: Optional[str] = None
*, category: Optional[str]=None, from_unit: Optional[str]=None
) -> Callable:
"""
Return a load_curve function pre-populated with the category
Expand Down Expand Up @@ -329,7 +334,7 @@ def load_dict_of_curves(

@lru_cache(maxsize=None)
def get_curve_dict_loader(
*, category: Optional[str] = None, from_unit: Optional[str] = None
*, category: Optional[str]=None, from_unit: Optional[str]=None
) -> Callable:
"""
Return a load_dict_of_curves function pre-populated with the category
Expand Down Expand Up @@ -422,6 +427,7 @@ def load_path(key: str, alfacase_content: DescriptionDocument) -> Path:


def load_pvt_tables(alfacase_content: DescriptionDocument) -> Dict[str, Path]:

def get_table_file(value):
"""
Value can be:
Expand Down Expand Up @@ -1492,7 +1498,7 @@ def load_annulus_description(
"pvt_model": load_value,
"top_node": load_value,
"initial_conditions": load_initial_conditions_description,
"equipment": load_annulus_equipment_description,
"equipment": partial(load_instance, class_=case_description.AnnulusEquipmentDescription),
}
case_values = to_case_values(document, alfacase_to_case_description)
item_description = case_description.AnnulusDescription(**case_values)
Expand Down Expand Up @@ -1731,26 +1737,6 @@ def load_wall_description(
]


def load_leak_equipment_description(
document: DescriptionDocument,
) -> List[case_description.LeakEquipmentDescription]:
alfacase_to_case_description = get_case_description_attribute_loader_dict(
case_description.LeakEquipmentDescription
)

def generate_leak_equipment_description(document: DescriptionDocument):
case_values = to_case_values(document, alfacase_to_case_description)
item_description = case_description.LeakEquipmentDescription(**case_values)
return update_multi_input_flags(document, item_description)

return {
key.data: generate_leak_equipment_description(
DescriptionDocument(value, document.file_path)
)
for key, value in document.content.items()
}


def load_equipment_description(
document: DescriptionDocument,
) -> case_description.EquipmentDescription:
Expand All @@ -1762,7 +1748,7 @@ def load_equipment_description(
"reservoir_inflows": load_reservoir_inflow_equipment_description,
"heat_sources": load_heat_source_equipment_description,
"compressors": load_compressor_equipment_description,
"leaks": load_leak_equipment_description,
"leaks": partial(load_instance, class_=case_description.LeakEquipmentDescription),
}
return _generate_description(
document,
Expand All @@ -1771,20 +1757,6 @@ def load_equipment_description(
)


def load_annulus_equipment_description(
document: DescriptionDocument,
) -> case_description.AnnulusEquipmentDescription:
alfacase_to_case_description = {
"gas_lift_valves": load_gas_lift_valve_equipment_description,
"leaks": load_leak_equipment_description,
}
return _generate_description(
document,
alfacase_to_case_description,
case_description.AnnulusEquipmentDescription,
)


def load_x_and_y_description(
document: DescriptionDocument,
) -> case_description.XAndYDescription:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,8 +402,12 @@ def attrib_dict_of(type_: type) -> attr._make._CountingAttr:
Create a new attr attribute with validator for an atribute that is a dictionary with keys as str (to represent
the name) and the content of an instance of type_
"""
metadata = {"type": "dict_of", "class_": type_}
return attr.ib(
default=attr.Factory(dict), validator=dict_of(type_), type=Dict[str, type_]
default=attr.Factory(dict),
validator=dict_of(type_),
type=Dict[str, type_],
metadata=metadata,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,15 +419,6 @@
WALL_LAYER_DESCRIPTION = case_description.WallLayerDescription(
thickness=Scalar(25.4, "mm"), material_name="Carbon Steel", has_annulus_flow=True
)
ANNULUS_DESCRIPTION = case_description.AnnulusDescription(
has_annulus_flow=True,
pvt_model="gavea",
top_node="mass_source_node",
initial_conditions=INITIAL_CONDITIONS_DESCRIPTION,
equipment=case_description.AnnulusEquipmentDescription(
gas_lift_valves={"My gas-lift valve": GAS_LIST_VALVE_DESCRIPTION},
),
)
CASE_OUTPUT_DESCRIPTION = case_description.CaseOutputDescription(
trends=TRENDS_OUTPUT_DESCRIPTION,
trend_frequency=Scalar(0.1, "s"),
Expand Down Expand Up @@ -526,6 +517,13 @@
leaks={"LEAK": LEAK_EQUIPMENT_DESCRIPTION},
gas_lift_valves={"GAS LIFT VALVE": GAS_LIST_VALVE_DESCRIPTION},
)
ANNULUS_DESCRIPTION = case_description.AnnulusDescription(
has_annulus_flow=True,
pvt_model="gavea",
top_node="mass_source_node",
initial_conditions=INITIAL_CONDITIONS_DESCRIPTION,
equipment=ANNULUS_EQUIPMENT_DESCRIPTION,
)
ENVIRONMENT_PROPERTY_DESCRIPTION = case_description.EnvironmentPropertyDescription(
position=Scalar(1, "m"),
temperature=Scalar(1, "degC"),
Expand Down

0 comments on commit 4a62789

Please sign in to comment.