Skip to content

Commit

Permalink
test for find functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
konstantinstadler committed Dec 15, 2023
1 parent ea6de8d commit 5b670a8
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 87 deletions.
74 changes: 47 additions & 27 deletions pymrio/core/mriosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ def rename_Y_categories(self, Y_categories):
return self

def find(self, term):
    """Look for term in index, sectors, regions, Y_categories and extensions.

    Mostly useful for a quick check if an entry is present.

    Parameters
    ----------
    term : str
        String or regex pattern to search for
        (forwarded to ioutil.index_contains via find_all).

    Returns
    -------
    dict
        Keys among "index", "regions", "sectors", "Y_categories" and,
        for every extension, "<extension>_index".
        Empty keys are omitted.
        The values can be used directly on one of the DataFrames with .loc
    """
    res_dict = dict()

    def _search(key, index_getter):
        # Run one search target and record non-empty hits under `key`.
        # A target can legitimately be unavailable on a partially built
        # system, so failures skip that target instead of aborting the
        # whole search. `except Exception` (not a bare except) is used so
        # KeyboardInterrupt/SystemExit still propagate.
        try:
            hits = ioutil.index_contains(index_getter(), find_all=term)
            if len(hits) > 0:
                res_dict[key] = hits
        except Exception:  # noqa: BLE001 - best-effort search target
            pass

    _search("index", lambda: self.get_index(as_dict=False))
    _search("regions", self.get_regions)
    _search("sectors", self.get_sectors)
    _search("Y_categories", self.get_Y_categories)

    try:
        for ext in self.get_extensions(data=False):
            # NOTE(review): assumes each extension object provides
            # get_index(as_dict=False) like the core system does.
            _search(
                ext + "_index",
                lambda name=ext: getattr(self, name).get_index(as_dict=False),
            )
    except Exception:  # noqa: BLE001 - no/odd extensions: skip silently
        pass

    return res_dict



# API classes
class Extension(BaseSystem):
"""Class which gathers all information for one extension of the IOSystem
Expand Down Expand Up @@ -2764,20 +2783,21 @@ def characterize(extension, char_factors, fallback=None):
if not ioutil.check_if_long(extension, value_col="value"):
extension = ioutil.convert_to_long(extension)


# TODO: move to util and write test
# TODO: make method for core, each extension and all extensions and test
# def regex_match(df_ix, **kwargs):
# """ Match index of df with regex
#
#
# The index levels need to be named (df.index.name needs to be set for all levels).
#
# Note
# -----
# The matching is done with str.fullmatch.
# Thus the passed pattern needs to match the full entry.
# This can be converted into matching only the
# This can be converted into matching only the
# beginning (simulating str.match) by appending '.*' to the pattern.
# To get the same behaviour as str.contains, append '.*' to
# To get the same behaviour as str.contains, append '.*' to
# the beginning and end of the pattern.
#
# Arguments of fullmatch are set to case=True, flags=0, na=False.
Expand All @@ -2791,7 +2811,7 @@ def characterize(extension, char_factors, fallback=None):
# df_ix : pd.DataFrame, pd.Series, pd.Index or pd.MultiIndex
# Rows/Index will be matched
# kwargs : dict
# The regex to match. The keys are the index names,
# The regex to match. The keys are the index names,
# the values are the regex to match.
# If the entry is not in index name, it is ignored silently.
#
Expand All @@ -2805,14 +2825,14 @@ def characterize(extension, char_factors, fallback=None):
# for key, value in kwargs.items():
# try:
# if type(df_ix) in [pd.DataFrame, pd.Series]:
# df_ix = df_ix[df_ix.index.get_level_values(key).str.fullmatch(value,
# df_ix = df_ix[df_ix.index.get_level_values(key).str.fullmatch(value,
# case=True,
# flags=0,
# flags=0,
# na=False)]
# elif type(df_ix) in [pd.Index, pd.MultiIndex]:
# df_ix = df_ix[df_ix.get_level_values(key).str.fullmatch(value,
# df_ix = df_ix[df_ix.get_level_values(key).str.fullmatch(value,
# case=True,
# flags=0,
# flags=0,
# na=False)]
# except KeyError:
# pass
Expand All @@ -2825,6 +2845,7 @@ def characterize(extension, char_factors, fallback=None):
#
# mm = regex_match(tt.emissions.F, compartment='air', abc="raba")


def match_and_convert(
src=None, bridge=None, src_match_col=None, bridge_match_col=None, agg_method=None
):
Expand All @@ -2840,9 +2861,9 @@ def match_and_convert(
TODO: Assumption on the data format of src:
all non numerical columns are set as index, table can be in long or wide format, all "proper" columns are numerical
also set string columns as index, even if they are not used for matching (e.g. units).
also set string columns as index, even if they are not used for matching (e.g. units).
Parameters
----------
Expand Down Expand Up @@ -2877,7 +2898,6 @@ def match_and_convert(

tt = pymrio.load_test()


# What we need in cc_headers
# src_match_col_1
# bridge_match_col_1
Expand Down
27 changes: 7 additions & 20 deletions pymrio/tools/ioutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
import requests
import urllib3

from pymrio.core.constants import (DEFAULT_FILE_NAMES, LONG_VALUE_NAME,
PYMRIO_PATH)
from pymrio.core.constants import DEFAULT_FILE_NAMES, LONG_VALUE_NAME, PYMRIO_PATH


def is_vector(inp):
Expand Down Expand Up @@ -154,9 +153,7 @@ def get_file_para(path, path_in_arc=""):

if zipfile.is_zipfile(str(path)):
with zipfile.ZipFile(file=str(path)) as zf:
para_file_content = json.loads(
zf.read(para_file_full_path).decode("utf-8")
)
para_file_content = json.loads(zf.read(para_file_full_path).decode("utf-8"))
else:
with open(para_file_full_path, "r") as pf:
para_file_content = json.load(pf)
Expand Down Expand Up @@ -281,9 +278,7 @@ def diagonalize_columns_to_sectors(
"""

sectors = df.index.get_level_values(sector_index_level).unique()
sector_name = (
sector_index_level if type(sector_index_level) is str else "sector"
)
sector_name = sector_index_level if type(sector_index_level) is str else "sector"

new_col_index = [
tuple(list(orig) + [new]) for orig in df.columns for new in sectors
Expand Down Expand Up @@ -402,8 +397,7 @@ def set_block(arr, arr_block):
)
if nr_row / nr_row_block != nr_col / nr_col_block:
raise ValueError(
"Block array can not be filled as "
"diagonal blocks in the given array"
"Block array can not be filled as " "diagonal blocks in the given array"
)

arr_out = arr.copy()
Expand Down Expand Up @@ -519,8 +513,7 @@ def build_agg_vec(agg_vec, **source):
]
else:
agg_dict[entry] = [
None if ee == "None" else ee
for ee in _tmp[:, -1].tolist()
None if ee == "None" else ee for ee in _tmp[:, -1].tolist()
]
break
else:
Expand Down Expand Up @@ -619,11 +612,7 @@ def read_first_lines(filehandle):

sep_aly_lines = [
sorted(
[
(line.count(sep), sep)
for sep in potential_sep
if line.count(sep) > 0
],
[(line.count(sep), sep) for sep in potential_sep if line.count(sep) > 0],
key=lambda x: x[0],
reverse=True,
)
Expand Down Expand Up @@ -938,9 +927,7 @@ def _index_regex_matcher(_dfs_idx, _method, _find_all=None, **kwargs):
"""
if _method not in ["contains", "match", "fullmatch"]:
raise ValueError(
'Method must be one of "contains", "match", "fullmatch"'
)
raise ValueError('Method must be one of "contains", "match", "fullmatch"')

if _find_all is not None:
if type(_dfs_idx) in [pd.DataFrame, pd.Series]:
Expand Down
17 changes: 17 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def test_diag_stressor(fix_testmrio):
# )
#


def test_characterize_extension(fix_testmrio):
factors = pd.read_csv(
Path(PYMRIO_PATH["test_mrio"] / Path("concordance") / "emissions_charact.tsv"),
Expand Down Expand Up @@ -494,6 +495,22 @@ def test_reset_to_coefficients(fix_testmrio):
assert tt.Z is None
assert tt.emissions.F is None

def test_find(fix_testmrio):
    """Check that find() locates entries across core and extension indexes."""
    tt = fix_testmrio.testmrio

    # A match-everything pattern must return every index completely.
    all_found = tt.find(".*")
    assert all(all_found["sectors"] == tt.get_sectors())
    assert all(all_found["regions"] == tt.get_regions())
    assert all(all_found["Y_categories"] == tt.get_Y_categories())
    assert all(all_found["index"] == tt.get_index())

    for ext in tt.get_extensions(data=False):
        # getattr instead of tt.__dict__[ext]: also works if the
        # extension is exposed via a property/descriptor.
        assert all(all_found[ext + "_index"] == getattr(tt, ext).get_index())

    # "air" is expected to match only extension entries (e.g. emission
    # compartments), so the core categories must be absent from the result.
    ext_find = tt.find("air")
    assert "sectors" not in ext_find.keys()
    assert "regions" not in ext_find.keys()
    assert "Y_categories" not in ext_find.keys()

def test_direct_account_calc(fix_testmrio):
orig = fix_testmrio.testmrio
Expand Down
58 changes: 18 additions & 40 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,18 @@
from pymrio.tools.ioutil import find_first_number # noqa
from pymrio.tools.ioutil import set_block # noqa
from pymrio.tools.ioutil import sniff_csv_format # noqa
from pymrio.tools.ioutil import (diagonalize_columns_to_sectors, # noqa
filename_from_url, index_contains,
index_fullmatch, index_match)
from pymrio.tools.ioutil import (
diagonalize_columns_to_sectors, # noqa
filename_from_url,
index_contains,
index_fullmatch,
index_match,
)


@pytest.fixture()
def csv_test_files_content():
test_para = namedtuple(
"test_para", ["text", "sep", "header_rows", "index_col"]
)
test_para = namedtuple("test_para", ["text", "sep", "header_rows", "index_col"])

class example_csv_content:
test_contents = [
Expand Down Expand Up @@ -207,9 +209,7 @@ def test_diagonalize_columns_to_sectors():
inp_df = pd.DataFrame(data=inp_array, index=reg_sec_index, columns=regions)
inp_df.columns.names = ["region"]

out_df = pd.DataFrame(
data=out_array, index=inp_df.index, columns=reg_sec_index
)
out_df = pd.DataFrame(data=out_array, index=inp_df.index, columns=reg_sec_index)

diag_df = diagonalize_columns_to_sectors(inp_df)
pdt.assert_frame_equal(diag_df, out_df)
Expand Down Expand Up @@ -291,65 +291,45 @@ def test_util_regex():
mdx_match = index_fullmatch(test_index, region=".*2", sector="cc")

assert (
len(
mdx_match.get_level_values("region")
.unique()
.difference({"c2", "b2"})
)
== 0
len(mdx_match.get_level_values("region").unique().difference({"c2", "b2"})) == 0
)

test_ds = test_df.foo
ds_match = index_fullmatch(test_ds, sector="aa")

assert ds_match.index.get_level_values("sector").unique() == ["aa"]
assert all(
ds_match.index.get_level_values("region").unique()
== ["a1", "b1", "c2", "b2"]
ds_match.index.get_level_values("region").unique() == ["a1", "b1", "c2", "b2"]
)

idx_match = index_fullmatch(
test_index.get_level_values("region"), region=".*2"
)
idx_match = index_fullmatch(test_index.get_level_values("region"), region=".*2")
assert (
len(
idx_match.get_level_values("region")
.unique()
.difference({"c2", "b2"})
)
== 0
len(idx_match.get_level_values("region").unique().difference({"c2", "b2"})) == 0
)

# test with empty dataframes
test_empty = pd.DataFrame(index=test_index)
df_match_empty = index_fullmatch(test_empty, region=".*b.*", sector=".*b.*")

assert all(
df_match_empty.index.get_level_values("region").unique() == ["b1", "b2"]
)
assert all(df_match_empty.index.get_level_values("region").unique() == ["b1", "b2"])
assert df_match_empty.index.get_level_values("sector").unique() == ["bb"]

# test with empty index
empty_index = pd.MultiIndex.from_product(
[[], []], names=["region", "sector"]
)
empty_index = pd.MultiIndex.from_product([[], []], names=["region", "sector"])

assert len(index_fullmatch(empty_index, region=".*", sector="cc")) == 0

# 2. test the contains functionality

df_match_contains = index_contains(test_df, region="1", sector="c")
assert all(
df_match_contains.index.get_level_values("region").unique()
== ["a1", "b1"]
df_match_contains.index.get_level_values("region").unique() == ["a1", "b1"]
)
assert df_match_contains.index.get_level_values("sector").unique() == ["cc"]

# 3. test the match functionality
df_match_match = index_match(test_df, region="b")
assert all(
df_match_match.index.get_level_values("region").unique() == ["b1", "b2"]
)
assert all(df_match_match.index.get_level_values("region").unique() == ["b1", "b2"])

# 4. test the findall functionality
df_match_findall = index_contains(test_df, find_all="c")
Expand All @@ -358,6 +338,4 @@ def test_util_regex():

# 5. test wrong input
with pytest.raises(ValueError):
index_fullmatch(
"foo", region="a.*", sector=".*b.*", not_present_column="abc"
)
index_fullmatch("foo", region="a.*", sector=".*b.*", not_present_column="abc")

0 comments on commit 5b670a8

Please sign in to comment.