diff --git a/ogs6py/classes/curve.py b/ogs6py/classes/curve.py new file mode 100644 index 0000000..418ce11 --- /dev/null +++ b/ogs6py/classes/curve.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +""" +Copyright (c) 2012-2023, OpenGeoSys Community (http://www.opengeosys.org) + Distributed under a Modified BSD License. + See accompanying file LICENSE or + http://www.opengeosys.org/project/license + +""" +from fastcore.utils import * +import numpy as np +from lxml import etree as ET +# pylint: disable=C0103, R0902, R0914, R0913 + +class Curve: + def __init__(self, xmlobject:object=None) -> None: + self.__dict__ = {} + self.xmlobject = xmlobject + if not self.xmlobject is None: + for curve_property in self.xmlobject: + if curve_property.tag == "name": + self.__dict__[curve_property.tag] = curve_property.text + else: + self.__dict__[curve_property.tag] = curve_property.text.split(" ") + assert(len(self.__dict__["coords"])==len(self.__dict__["values"])) + + def __setitem__(self, key, item): + if not key in self.__dict__: + raise RuntimeError("property is not existing") + expression_counter = -1 + if key in ["coords", "values"]: + for i, entry in enumerate(item): + item[i] = str(entry) + for curve_property in self.xmlobject: + if curve_property.tag == key: + curve_property.text = ' '.join(item) + else: + for curve_property in self.xmlobject: + if curve_property.tag == key: + curve_property.text = item + + def __getitem__(self, key): + if not (key == "xmlobject"): + return self.__dict__[key] + + def __repr__(self): + newdict = {} + for k, v in self.__dict__.items(): + if not (k == "xmlobject"): + newdict[k] = v + return repr(newdict) + + def __len__(self): + return len(self.__dict__) + + def __delitem__(self, key): + pass + #del self.__dict__[key] + + def clear(self): + return self.__dict__.clear() + + def copy(self): + return self.__dict__.copy() + + def has_key(self, k): + if not (k == "xmlobject"): + return k in self.__dict__ + + def update(self, *args, **kwargs): + pass + # return self.__dict__.update(*args, **kwargs) + + def keys(self): + newdict = {} + for k, v in self.__dict__.items(): + if not (k == "xmlobject"): + newdict[k] = v + return newdict.keys() + + def items(self): + newdict = {} + for k, v in self.__dict__.items(): + if not (k == "xmlobject"): + newdict[k] = v + return newdict.items() + + def pop(self, *args): + pass + #return self.__dict__.pop(*args) + + def __cmp__(self, dict_): + return self.__cmp__(self.__dict__, dict_) + + def __contains__(self, item): + newdict = {} + for k, v in self.__dict__.items(): + if not (k == "xmlobject"): + newdict[k] = v + return item in newdict + + def __iter__(self): + newdict = {} + for k, v in self.__dict__.items(): + if not (k == "xmlobject"): + newdict[k] = v + return iter(newdict) + + def evaluate_values(self,x): + coords = np.array([float(val) for val in self.__dict__["coords"]]) + values = np.array([float(val) for val in self.__dict__["values"]]) + return np.interp(x, coords, values) diff --git a/ogs6py/classes/curves.py b/ogs6py/classes/curves.py index 287d8de..567d948 100644 --- a/ogs6py/classes/curves.py +++ b/ogs6py/classes/curves.py @@ -1,19 +1,21 @@ # -*- coding: utf-8 -*- """ -Copyright (c) 2012-2021, OpenGeoSys Community (http://www.opengeosys.org) +Copyright (c) 2012-2023, OpenGeoSys Community (http://www.opengeosys.org) Distributed under a Modified BSD License. See accompanying file LICENSE or http://www.opengeosys.org/project/license """ # pylint: disable=C0103, R0902, R0914, R0913 +from lxml import etree as ET from ogs6py.classes import build_tree +from ogs6py.classes import curve class Curves(build_tree.BuildTree): """ Class to create the curve section of the project file. """ - def __init__(self): + def __init__(self, xmlobject=None): self.tree = { 'curves': { 'tag': 'curves', @@ -22,6 +24,121 @@ def __init__(self): 'children': {} } } + self.xmlobject = xmlobject + if not (xmlobject is None): + for curve_obj in xmlobject: + for curve_property in curve_obj: + if curve_property.tag == "name": + curve_name = curve_property.text + self.__dict__[curve_name] = curve.Curve(curve_obj) + + def __checkcurve(self, dictionary): + required = ["coords", "values"] + optional = ["name"] + for k, v in dictionary.items(): + if not k in (required+optional): + raise RuntimeError(f"{k} is not a valid property field for a curve.") + for entry in required: + if not entry in dictionary: + raise RuntimeError(f"{entry} is required for creating a curve.") + + def __setitem__(self, key, item): + if not isinstance(item, dict): + raise RuntimeError("Item must be a dictionary") + if len(item) == 0: + self.__delitem__(key) + return + self.__checkcurve(item) + if key in self.__dict__: + self.__delitem__(key) + curve_obj = ET.SubElement(self.xmlobject, "curve") + assert(len(item["coords"])==len(item["values"])) + for i, entry in enumerate(item["coords"]): + item["coords"][i] = str(entry) + item["values"][i] = str(item["values"][i]) + q = ET.SubElement(curve_obj, "name") + q.text = key + q = ET.SubElement(curve_obj, "coords") + q.text = ' '.join(item["coords"]) + q = ET.SubElement(curve_obj, "values") + q.text = ' '.join(item["values"]) + return curve_obj + + + def __getitem__(self, key): + if not (key in ["tree", "xmlobject"]): + return self.__dict__[key] + + def __repr__(self): + newdict = {} + for k, v in self.__dict__.items(): + if not (k in ["tree", "name", "xmlobject"]): + newdict[k] = v + return repr(newdict) + + def __len__(self): + return len(self.__dict__) + + def __delitem__(self, key): + obj = self.__dict__[key].xmlobject + obj.getparent().remove(obj) + del self.__dict__[key] + + def clear(self): + return self.__dict__.clear() + + def copy(self): + return self.__dict__.copy() + + def has_key(self, k): + if not (k in ["tree","xmlobject"]): + return k in self.__dict__ + + def update(self, *args, **kwargs): + pass + # return self.__dict__.update(*args, **kwargs) + + def keys(self): + newdict = {} + for k, v in self.__dict__.items(): + if not (k in ["tree","xmlobject"]): + newdict[k] = v + return newdict.keys() + + def values(self): + newdict = {} + for k, v in self.__dict__.items(): + if not (k in ["tree", "xmlobject"]): + newdict[k] = v + return newdict.values() + + def items(self): + newdict = {} + for k, v in self.__dict__.items(): + if not (k in ["tree", "xmlobject"]): + newdict[k] = v + return newdict.items() + + def pop(self, *args): + pass + #return self.__dict__.pop(*args) + + def __cmp__(self, dict_): + return self.__cmp__(self.__dict__, dict_) + + def __contains__(self, item): + newdict = {} + for k, v in self.__dict__.items(): + if not (k in ["tree", "xmlobject"]): + newdict[k] = v + return item in newdict + + def __iter__(self): + newdict = {} + for k, v in self.__dict__.items(): + if not (k in ["tree", "xmlobject"]): + newdict[k] = v + return iter(newdict) def add_curve(self, **args): """ diff --git a/ogs6py/classes/local_coordinate_system.py b/ogs6py/classes/local_coordinate_system.py index e82e15d..da56a7b 100644 --- a/ogs6py/classes/local_coordinate_system.py +++ b/ogs6py/classes/local_coordinate_system.py @@ -6,6 +6,7 @@ http://www.opengeosys.org/project/license """ +import numpy as np # pylint: disable=C0103, R0902, R0914, R0913 from ogs6py.classes import build_tree @@ -13,7 +14,7 @@ class LocalCoordinateSystem(build_tree.BuildTree): """ Class for defining a local coordinate system in the project file" """ - def __init__(self): + def __init__(self, xmlobject=None): self.tree = { 'local_coordinate_system': { 'tag': 'local_coordinate_system', @@ -22,6 +23,21 @@ def __init__(self): 'children': {} } } + self.xmlobject = xmlobject + self.R = None + if not self.xmlobject is None: + basis_vectors = self.xmlobject.getchildren() + dim = len(basis_vectors) + self.R = np.zeros((dim,dim)) + basis_vector_names = [] + for vec in basis_vectors: + basis_vector_names.append(vec.text) + basis_vector_values = [] + for vec in basis_vector_names: + basis_vector_values.append(np.fromstring(self.xmlobject.getparent().find(f"./parameters/parameter[name='{vec}']/values").text, sep=' ')) + for i in range(dim): + for j in range(dim): + self.R[i,j] = basis_vector_values[i][j] def add_basis_vec(self, **args): """ diff --git a/ogs6py/classes/parameter_type.py b/ogs6py/classes/parameter_type.py new file mode 100644 index 0000000..9a6d469 --- /dev/null +++ b/ogs6py/classes/parameter_type.py @@ -0,0 +1,269 @@ +# -*- coding: utf-8 -*- +""" +Copyright (c) 2012-2023, OpenGeoSys Community (http://www.opengeosys.org) + Distributed under a Modified BSD License. + See accompanying file LICENSE or + http://www.opengeosys.org/project/license + +""" +from fastcore.utils import * +import numpy as np +import ctypes +from lxml import etree as ET +try: + import vtuIO + has_vtuinterface = True +except ImportError: + has_vtuinterface = False +try: + import cexprtk + has_cexprtk = True +except ImportError: + has_cexprtk = False + +# pylint: disable=C0103, R0902, R0914, R0913 + +class Parameter_type: + def __init__(self, xmlobject:object=None, paramobject=None, curvesobj=None, trafo_matrix=None) -> None: + self.__dict__ = {} + self.xmlobject = xmlobject + self.paramobject = id(paramobject) + self.curvesobj = curvesobj + self.trafo_matrix = trafo_matrix + if not self.xmlobject is None: + for parameter_property in self.xmlobject: + if parameter_property.tag == "expression": + try: + self.__dict__["expression"].append(parameter_property.text) + except: + self.__dict__["expression"] = [] + self.__dict__["expression"].append(parameter_property.text) + elif parameter_property.tag == "value": + self.__dict__[parameter_property.tag] = float(parameter_property.text) + elif parameter_property.tag == "values": + # convert to tensor + include local coordinate system + self.__dict__[parameter_property.tag] = np.fromstring(parameter_property.text, sep=' ') + else: + self.__dict__[parameter_property.tag] = parameter_property.text + def __setitem__(self, key, item): + if not key in self.__dict__: + raise RuntimeError("property is not existing") + if key == "type": + raise RuntimeError("The Type can't be changed.") + expression_counter = -1 + for parameter_property in self.xmlobject: + if parameter_property.tag == key: + if key == "expression": + expression_counter += 1 + parameter_property.text = str(item[expression_counter]) + else: + parameter_property.text = str(item) + + def __getitem__(self, key): + if not (key in ["xmlobject", "curvesobj", "trafo_matrix"]): + return self.__dict__[key] + + def __repr__(self): + newdict = {} + for k, v in self.__dict__.items(): + if not (k in ["xmlobject", "curvesobj", "trafo_matrix"]): + newdict[k] = v + return repr(newdict) + + def __len__(self): + return len(self.__dict__) + + def __delitem__(self, key): + pass + #del self.__dict__[key] + + def clear(self): + return self.__dict__.clear() + + def copy(self): + return self.__dict__.copy() + + def has_key(self, k): + if not (k in ["xmlobject", "curvesobj", "trafo_matrix", "paramobject"]): + return k in self.__dict__ + + def update(self, *args, **kwargs): + pass + # return self.__dict__.update(*args, **kwargs) + + def keys(self): + newdict = {} + for k, v in self.__dict__.items(): + if not (k in ["xmlobject", "curvesobj", "trafo_matrix", "paramobject"]): + newdict[k] = v + return newdict.keys() + + def values(self): + newdict = {} + for k, v in self.__dict__.items(): + if not (k in ["xmlobject", "curvesobj", "trafo_matrix", "paramobject"]): + newdict[k] = v + return newdict.values() + + def items(self): + newdict = {} + for k, v in self.__dict__.items(): + if not (k in ["xmlobject", "curvesobj", "trafo_matrix", "paramobject"]): + newdict[k] = v + return newdict.items() + + def pop(self, *args): + pass + #return self.__dict__.pop(*args) + + def __cmp__(self, dict_): + return self.__cmp__(self.__dict__, dict_) + + def __contains__(self, item): + newdict = {} + for k, v in self.__dict__.items(): + if not (k in ["xmlobject", "curvesobj", "trafo_matrix", "paramobject"]): + newdict[k] = v + return item in newdict + + def __iter__(self): + newdict = {} + for k, v in self.__dict__.items(): + if not (k in ["xmlobject", "curvesobj", "trafo_matrix", "paramobject"]): + newdict[k] = v + return iter(newdict) + +class Constant(Parameter_type): + def evaluate_values(self): + if "values" in self.__dict__: + values_size = len(self.__dict__["values"]) + if values_size == 1: + return self.__dict__["values"] + elif "local_coordinate_system" in self.__dict__: + if self.__dict__["local_coordinate_system"] == "true": + dim_trafo = self.trafo_matrix.shape[0] + if values_size == dim_trafo: + return np.matmul(self.trafo_matrix, self.__dict__["values"]) + elif values_size == dim_trafo**2: + values_matrix = np.zeros((dim_trafo,dim_trafo)) + mapping = {} + mapping[2] = {0: (0,0), 1: (0,1), 2: (1,0), 3: (1,1)} + mapping[3] = {0: (0,0), 1: (0,1), 2: (0,2), + 3: (1,0), 4: (1,1), 5: (1,2), + 6: (2,0), 7: (2,1), 8: (2,2)} + for i, val in enumerate(self.__dict__["values"]): + values_matrix[mapping[dim_trafo][i][0],mapping[dim_trafo][i][1]] = val + return np.matmul(self.trafo_matrix,np.matmul(values_matrix,self.trafo_matrix.transpose())) + else: + raise RuntimeError("Parameter size in combination with local coordinate system is not supported yet.") + else: + return self.__dict__["values"] + else: + return self.__dict__["value"] + +class Function(Parameter_type): + def evaluate_values(self, t=0): + if has_vtuinterface is False: + raise RuntimeError("VTUinterface is not installed") + if has_cexprtk is False: + raise RuntimeError("cexprtk is not installed") + try: + mesh = self.__dict__["mesh"] + meshfiles = self.xmlobject.getparent().getparent().findall("./mesh") + for file in meshfiles: + if mesh in file.text: + meshfile = file.text + except KeyError: + meshfile = self.xmlobject.getparent().getparent().find("./mesh").text + m = vtuIO.VTUIO(meshfile) + st = cexprtk.Symbol_Table({'x': 0.0, 'y': 0.0, 'z': 0.0, 't': t}, add_constants=True) + try: + for curve in self.curvesobj.keys(): + st.functions[curve] = self.curvesobj[curve].evaluate_values + except: + pass + dim1 = len(m.points) + dim2 = len(self.__dict__["expression"]) + if dim2 == 1: + array = np.zeros(dim1) + evaluate = cexprtk.Expression(self.__dict__["expression"][0], st) + for i in range(dim1): + st.variables["x"] = m.points[i][0] + st.variables["y"] = m.points[i][1] + st.variables["z"] = m.points[i][2] + array[i] = evaluate() + else: + array = np.zeros((dim1, dim2)) + evaluate = [] + for i in range(dim2): + evaluate.append(cexprtk.Expression(self.__dict__["expression"][i], st)) + for i in range(dim1): + for j in range(dim2): + st.variables["x"] = m.points[i][0] + st.variables["y"] = m.points[i][1] + st.variables["z"] = m.points[i][2] + array[i,j]=evaluate[j]() + return m.points, array, m + + +class MeshNode(Parameter_type): + def evaluate_values(self): + if has_vtuinterface is False: + raise RuntimeError("vtuIO is not installed") + try: + mesh = self.__dict__["mesh"] + meshfiles = self.xmlobject.getparent().getparent().findall("./mesh") + for file in meshfiles: + if mesh in file.text: + meshfile = file.text + except KeyError: + meshfile = self.xmlobject.getparent().getparent().find("./mesh") + m = vtuIO.VTUIO(meshfile) + array = m.get_point_field(self.__dict__["field_name"]) + return m.points, array, m + +class MeshElement(Parameter_type): + def evaluate_values(self): + if has_vtuinterface is False: + raise RuntimeError("vtuIO is not installed") + try: + mesh = self.__dict__["mesh"] + meshfiles = self.xmlobject.getparent().getparent().findall("./mesh") + for file in meshfiles: + if mesh in file.text: + meshfile = file.text + except KeyError: + meshfile = self.xmlobject.getparent().getparent().find("./mesh") + m = vtuIO.VTUIO(meshfile) + array = m.get_cell_field(self.__dict__["field_name"]) + return m.cell_center_points, array, m + +class CurveScaled(Parameter_type): + def evaluate_values(self, curve_coords=None): + if curve_coords is None: + t_start = self.xmlobject.getparent().getparent().find("./time_loop/processes/process/time_stepping/t_initial") + t_end = self.xmlobject.getparent().getparent().find("./time_loop/processes/process/time_stepping/t_end") + curve_coords = np.linspace(t_start, t_end, 1000, endpoint=True) + parameter_name = self.__dict__["parameter"] + curve_name = self.__dict__["curve"] + parameter_type = self.xmlobject.getparent().find(f"./parameter[name='{parameter_name}'/type").text + try: + paramobject = ctypes.cast(self.paramobject, ctypes.py_object).value + parameter_value = paramobject[parameter_name].evaluate_values() + except: + raise RuntimeError("Can't find parameter.") + if parameter_type == "Constant": + if len(parameter_value) == 1: + return parameter_value*self.curvesobj[curve_name].evaluate_values(curve_coords) + else: + return [val*self.curvesobj[curve_name].evaluate_values(curve_coords) for val in parameter_value] + else: + raise RuntimeError("This function is implemented constant parameter types only") + +#class TimeDependentHeterogeneousParameter(Parameter_type): +# pass + +class RandomFieldMeshElementParameter(Parameter_type): + pass +#class Group(Parameter_type): +# pass \ No newline at end of file diff --git a/ogs6py/classes/parameters.py b/ogs6py/classes/parameters.py index be3e139..b87bf4d 100644 --- a/ogs6py/classes/parameters.py +++ b/ogs6py/classes/parameters.py @@ -7,13 +7,16 @@ """ # pylint: disable=C0103, R0902, R0914, R0913 +import numpy as np +from lxml import etree as ET from ogs6py.classes import build_tree +from ogs6py.classes import parameter_type class Parameters(build_tree.BuildTree): """ Class for managing the parameters section of the project file. """ - def __init__(self): + def __init__(self, xmlobject=None, curvesobj=None, trafo_matrix=None): self.tree = { 'parameters': { 'tag': 'parameters', @@ -22,6 +25,174 @@ def __init__(self): 'children': {} } } + self.parameter = {} + self.xmlobject = xmlobject + self.curvesobj = curvesobj + self.trafo_matrix = trafo_matrix + if not (xmlobject is None): + for prmt in xmlobject: + for parameter_property in prmt: + if parameter_property.tag == "type": + param_type = parameter_property.text + elif parameter_property.tag == "name": + param_name = parameter_property.text + if param_type == "Constant": + self.__dict__[param_name] = parameter_type.Constant(prmt, self, curvesobj, trafo_matrix) + elif param_type == "Function": + self.__dict__[param_name] = parameter_type.Function(prmt, self, curvesobj, trafo_matrix) + elif param_type == "MeshNode": + self.__dict__[param_name] = parameter_type.MeshNode(prmt, self, curvesobj, trafo_matrix) + elif param_type == "MeshElement": + self.__dict__[param_name] = parameter_type.MeshElement(prmt, self, curvesobj, trafo_matrix) + elif param_type == "CurveScaled": + self.__dict__[param_name] = parameter_type.CurveScaled(prmt, self, curvesobj, trafo_matrix) +# elif param_type == "TimeDependentHeterogeneousParameter": +# self.__dict__[param_name] = parameter_type.TimeDependentHeterogeneousParameter(prmt) + elif param_type == "RandomFieldMeshElementParameter": + self.__dict__[param_name] = parameter_type.RandomFieldMeshElementParameter(prmt, self, curvesobj, trafo_matrix) +# elif param_type == "Group": +# self.__dict__[param_name] = parameter_type.Group(prmt) + + def __checkparameter(self, dictionary): + required = {"Constant": ["name", "type"], + "Function": ["name", "type", "expression"], + "MeshNode": ["name", "type", "field_name"], + "MeshElement": ["name", "type", "field_name"], + "CurvedScaled": ["name", "type", "curve", "parameter"], + "TimeDependentHeterogeneousParameter": ["name", "type", "time_series"], + "RandomFieldMeshElementParameter": ["name", "type","field_name", "range", "seed"], + "Group": ["name", "type", "group_id_property"]} + optional = {"Constant": ["value", "values"], + "Function": ["mesh"], + "MeshNode": ["mesh"], + "MeshElement": ["mesh"], + "CurvedScaled": ["mesh"], + "TimeDependentHeterogeneousParameter": ["mesh"], + "RandomFieldMeshElementParameter": ["mesh"], + "Group": ["mesh"]} + for k, v in dictionary.items(): + if not k in (required[dictionary["type"]]+optional[dictionary["type"]]): + raise RuntimeError(f"{k} is not a valid property field for the specified type.") + for entry in required[dictionary["type"]]: + if not entry in dictionary: + raise RuntimeError(f"{entry} is required for the specified type.") + if dictionary["type"] == "Constant": + if not (("value" in dictionary) or ("values" in dictionary)): + raise RuntimeError("The Constant parameter requires value or values to be specified.") + + + def __setitem__(self, key, item): + if not isinstance(item, dict): + raise RuntimeError("Item must be a dictionary") + if len(item) == 0: + self.__delitem__(key) + return + self.__checkparameter(item) + if key in self.__dict__: + self.__delitem__(key) + prmt_obj = ET.SubElement(self.xmlobject, "parameter") + for k, v in item.items(): + if k == "expression": + q = [] + for subentry in v: + q.append(ET.SubElement(prmt_obj, "expression")) + q[-1].text = subentry + else: + q = ET.SubElement(prmt_obj, k) + q.text = v + if item["type"] == "Constant": + self.__dict__[key] = parameter_type.Constant(prmt_obj, self, self.curvesobj, self.trafo_matrix) + elif item["type"] == "Function": + self.__dict__[key] = parameter_type.Function(prmt_obj, self, self.curvesobj, self.trafo_matrix) + elif item["type"] == "MeshNode": + self.__dict__[key] = parameter_type.MeshNode(prmt_obj, self, self.curvesobj, self.trafo_matrix) + elif item["type"] == "MeshElement": + self.__dict__[key] = parameter_type.MeshElement(prmt_obj,self, self.curvesobj, self.trafo_matrix) + elif item["type"] == "CurveScaled": + self.__dict__[key] = parameter_type.CurveScaled(prmt_obj, self, self.curvesobj, self.trafo_matrix) +# elif item["type"] == "TimeDependentHeterogeneousParameter": +# self.__dict__[param_name] = parameter_type.TimeDependentHeterogeneousParameter(prmt) + elif item["type"] == "RandomFieldMeshElementParameter": + self.__dict__[key] = parameter_type.RandomFieldMeshElementParameter(prmt_obj, self, self.curvesobj, self.trafo_matrix) +# elif item["type"] == "Group": +# self.__dict__[param_name] = parameter_type.Group(prmt) + return prmt_obj + + + def __getitem__(self, key): + if not (key in ["tree", "parameter", "xmlobject", "curvesobj", "trafo_matrix"]): + return self.__dict__[key] + + def __repr__(self): + newdict = dict() + for k, v in self.__dict__.items(): + if not (k in ["tree", "parameter", "name", "xmlobject", "curvesobj", "trafo_matrix"]): + newdict[k] = v + return repr(newdict) + + def __len__(self): + return len(self.__dict__) + + def __delitem__(self, key): + obj = self.__dict__[key].xmlobject + obj.getparent().remove(obj) + del self.__dict__[key] + + def clear(self): + return self.__dict__.clear() + + def copy(self): + return self.__dict__.copy() + + def has_key(self, k): + if not (k in ["tree", "parameter", "xmlobject", "curvesobj", "trafo_matrix"]): + return k in self.__dict__ + + def update(self, *args, **kwargs): + pass + # return self.__dict__.update(*args, **kwargs) + + def keys(self): + newdict = {} + for k, v in self.__dict__.items(): + if not (k in ["tree", "parameter", "xmlobject", "curvesobj", "trafo_matrix"]): + newdict[k] = v + return newdict.keys() + + def values(self): + newdict = {} + for k, v in self.__dict__.items(): + if not (k in ["tree", "parameter", "xmlobject", "curvesobj", "trafo_matrix"]): + newdict[k] = v + return newdict.values() + + def items(self): + newdict = {} + for k, v in self.__dict__.items(): + if not (k in ["tree", "parameter", "xmlobject", "curvesobj", "trafo_matrix"]): + newdict[k] = v + return newdict.items() + + def pop(self, *args): + pass + #return self.__dict__.pop(*args) + + def __cmp__(self, dict_): + return self.__cmp__(self.__dict__, dict_) + + def __contains__(self, item): + newdict = {} + for k, v in self.__dict__.items(): + if not (k in ["tree", "parameter", "xmlobject", "curvesobj", "trafo_matrix"]): + newdict[k] = v + return item in newdict + + def __iter__(self): + newdict = {} + for k, v in self.__dict__.items(): + if not (k in ["tree", "parameter", "xmlobject", "curvesobj", "trafo_matrix"]): + newdict[k] = v + return iter(newdict) def add_parameter(self, **args): """ diff --git a/ogs6py/ogs.py b/ogs6py/ogs.py index 5aca83b..a6517c6 100644 --- a/ogs6py/ogs.py +++ b/ogs6py/ogs.py @@ -18,6 +18,7 @@ import shutil import pandas as pd from lxml import etree as ET +from fastcore.utils import * from ogs6py.classes import (display, geo, mesh, python_script, processes, media, timeloop, local_coordinate_system, parameters, curves, processvars, linsolvers, nonlinsolvers) import ogs6py.log_parser.log_parser as parser @@ -26,31 +27,34 @@ class OGS: """Class for an OGS6 model. - In this class everything for an OGS5 model can be specified. + In this class everything for an OGS6 model can be specified. Parameters ---------- - PROJECT_FILE : `str`, optional + prjfile : `str`, optional Filename of the output project file Default: default.prj - INPUT_FILE : `str`, optional + input_file : `str`, optional Filename of the input project file - XMLSTRING : `str`,optional - OMP_NUM_THREADS : `int`, optional + xmlstring : `str`,optional + omp_num_threads : `int`, optional Sets the environmentvariable before OGS execution to restrict number of OMP Threads - VERBOSE : `bool`, optional + verbose : `bool`, optional Default: False """ - def __init__(self, **args): + def __init__(self, input_file=None, prjfile="default.prj", xmlstring=None, verbose=False, omp_num_threads=None, **args): self.geo = geo.Geo() self.mesh = mesh.Mesh() self.pyscript = python_script.PythonScript() self.processes = processes.Processes() self.media = media.Media() self.timeloop = timeloop.TimeLoop() - self.local_coordinate_system = local_coordinate_system.LocalCoordinateSystem() - self.parameters = parameters.Parameters() - self.curves = curves.Curves() + self.__local_coordinate_system = None + self.__parameters = None + self.__curves = None + self.__local_coordinate_system_obj = None + self.__parameters_obj = None + self.__curves_obj = None self.processvars = processvars.ProcessVars() self.linsolvers = linsolvers.LinSolvers() self.nonlinsolvers = nonlinsolvers.NonLinSolvers() @@ -61,32 +65,32 @@ def __init__(self, **args): self.include_elements = [] self.include_files = [] self.add_includes = [] + store_attr() + # **args only fror backwards compatibility if "VERBOSE" in args: self.verbose = args["VERBOSE"] else: self.verbose = False if "OMP_NUM_THREADS" in args: - self.threads = args["OMP_NUM_THREADS"] + self.omp_num_threads = args["OMP_NUM_THREADS"] else: - self.threads = None + self.omp_num_threads = None if "PROJECT_FILE" in args: self.prjfile = args['PROJECT_FILE'] - else: + if self.prjfile is None: print("PROJECT_FILE for output not given. Calling it default.prj.") self.prjfile = "default.prj" if "INPUT_FILE" in args: - if os.path.isfile(args['INPUT_FILE']) is True: - self.inputfile = args['INPUT_FILE'] + self.input_file = args['INPUT_FILE'] + if self.input_file is not None: + if os.path.isfile(self.input_file) is True: _ = self._get_root() if self.verbose is True: display.Display(self.tree) else: raise RuntimeError(f"Input project file {args['INPUT_FILE']} not found.") - else: - self.inputfile = None if "XMLSTRING" in args: - root = ET.fromstring(args['XMLSTRING']) - self.tree = ET.ElementTree(root) + self.xmlstring = args["XMLSTRING"] def __dict2xml(self, parent, dictionary): for entry in dictionary: @@ -115,8 +119,11 @@ def __replace_blocks_by_includes(self): def _get_root(self): if self.tree is None: - if self.inputfile is not None: - self.tree = ET.parse(self.inputfile) + if self.input_file is not None: + self.tree = ET.parse(self.input_file) + elif self.xmlstring is not None: + root = ET.fromstring(self.xmlstring) + self.tree = ET.ElementTree(root) else: self.build_tree() root = self.tree.getroot() @@ -136,6 +143,45 @@ def _get_root(self): self.include_elements.append(child) return root + @property + def parameters(self): + try: + paramobj = self.tree.find("./parameters") + if not (paramobj == self.__parameters_obj): + self.__parameters_obj = paramobj + self.__parameters = parameters.Parameters(xmlobject=paramobj, curvesobj=self.curves, trafo_matrix=self.local_coordinate_system.R) + except AttributeError: + paramobj = None + if self.__parameters is None: + self.__parameters = parameters.Parameters(xmlobject=paramobj, curvesobj=self.curves, trafo_matrix=self.local_coordinate_system.R) + return self.__parameters + + @property + def curves(self): + try: + curveobj = self.tree.find("./curves") + if not (curveobj == self.__curves_obj): + self.__curves_obj = curveobj + self.__curves = curves.Curves(xmlobject=curveobj) + except AttributeError: + curveobj = None + if self.__curves is None: + self.__curves = curves.Curves(xmlobject=curveobj) + return self.__curves + + @property + def local_coordinate_system(self): + try: + lcsobj = self.tree.find("./local_coordinate_system") + if not (lcsobj == self.__local_coordinate_system_obj): + self.__local_coordinate_system_obj = lcsobj + self.__local_coordinate_system = local_coordinate_system.LocalCoordinateSystem(xmlobject=lcsobj) + except AttributeError: + lcsobj = None + if self.__local_coordinate_system is None: + self.__local_coordinate_system = local_coordinate_system.LocalCoordinateSystem(xmlobject=lcsobj) + return self.__local_coordinate_system + @classmethod def _get_parameter_pointer(cls, root, name, xpath): params = root.findall(xpath) @@ -566,10 +612,10 @@ def run_model(self, logfile="out.log", path=None, args=None, container_path=None """ ogs_path = "" - if self.threads is None: + if self.omp_num_threads is None: env_export = "" else: - env_export = f"export OMP_NUM_THREADS={self.threads} && " + env_export = f"export OMP_NUM_THREADS={self.omp_num_threads} && " if not container_path is None: container_path = os.path.expanduser(container_path) if os.path.isfile(container_path) is False: diff --git a/setup.py b/setup.py index 1ea70ae..0349813 100644 --- a/setup.py +++ b/setup.py @@ -55,5 +55,6 @@ def find_version(*file_paths): include_package_data=True, python_requires='>=3.8', install_requires=["lxml","pandas"], + extras_require={"parameter_access": ["VTUinterface", "cexprtk"]}, py_modules=["ogs6py/ogs","ogs6py/log_parser/log_parser", "ogs6py/log_parser/common_ogs_analyses", "ogs6py/ogs_regexes/ogs_regexes"], packages=["ogs6py/classes","ogs6py/log_parser","ogs6py/ogs_regexes"])