Skip to content

Commit

Permalink
hotfix: fix rescale_values (#188)
Browse files Browse the repository at this point in the history
  • Loading branch information
jlsneto authored Jan 26, 2023
1 parent 76f1b08 commit 49dc1c0
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 33 deletions.
2 changes: 1 addition & 1 deletion cereja/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from . import experimental
from ._requests import request

VERSION = "1.8.2.final.0"
VERSION = "1.8.3.final.0"

__version__ = get_version_pep440_compliant(VERSION)

Expand Down
53 changes: 21 additions & 32 deletions cereja/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,6 @@
# Needed init configs
from ..config.cj_types import ClassType, FunctionType, Number

try:
import numpy as np

_has_numpy = True
except:
_has_numpy = False
__all__ = [
"CjTest",
"camel_to_snake",
Expand Down Expand Up @@ -222,12 +216,12 @@ def truncate(text: Union[str, bytes], k=15):
@param k: natural numbers, default is 4
"""
assert isinstance(text, (str, bytes)), TypeError(
f"{type(text)} isn't valid. Expected str or bytes"
f"{type(text)} isn't valid. Expected str or bytes"
)
if k > len(text) or k <= 4:
return text
n = int(
(k - 4) / 2
(k - 4) / 2
) # k is the max length of text, 4 is the length of truncate symbol
trunc_chars = "...." if isinstance(text, str) else b"...."
return text[:n] + trunc_chars + text[-n:]
Expand Down Expand Up @@ -439,7 +433,7 @@ def import_string(dotted_path):
return getattr(module, class_name)
except AttributeError as err:
raise ImportError(
f"Module {module_path} does not define a {class_name} attribute/class"
f"Module {module_path} does not define a {class_name} attribute/class"
) from err


Expand Down Expand Up @@ -511,7 +505,7 @@ def module_references(instance: types.ModuleType, **kwargs) -> dict:
:return: List[str]
"""
assert isinstance(
instance, types.ModuleType
instance, types.ModuleType
), "You need to submit a module instance."
logger.debug(f"Checking module {instance.__name__}")
definitions = {}
Expand Down Expand Up @@ -639,7 +633,7 @@ def n_checks(self):
@property
def _instance_obj_attrs(self):
return filter(
lambda attr_: attr_.__contains__("__") is False, dir(self._instance_obj)
lambda attr_: attr_.__contains__("__") is False, dir(self._instance_obj)
)

def _get_attr_obj(self, attr_: str):
Expand Down Expand Up @@ -736,7 +730,7 @@ def run(self, current_value):

def _valid_attr(self, attr_name: str):
assert hasattr(
self._instance_obj, attr_name
self._instance_obj, attr_name
), f"{self.__prefix_attr_err.format(attr_=repr(attr_name))} isn't defined."
return attr_name

Expand Down Expand Up @@ -768,11 +762,11 @@ def check_all(self):
@classmethod
def _get_class_test(cls, ref):
func_tests = "".join(
cls.__template_unittest_function.format(func_name=i)
for i in list_methods(ref)
cls.__template_unittest_function.format(func_name=i)
for i in list_methods(ref)
)
return cls.__template_unittest_class.format(
class_name=ref.__name__, func_tests=func_tests
class_name=ref.__name__, func_tests=func_tests
)

@classmethod
Expand Down Expand Up @@ -806,7 +800,7 @@ def build_test(cls, reference):
module_func_test = "".join(module_func_test)
tests = [
cls.__template_unittest_class.format(
class_name="Module", func_tests=module_func_test
class_name="Module", func_tests=module_func_test
)
] + tests
return cls.__template_unittest.format(tests="\n".join(tests))
Expand Down Expand Up @@ -867,14 +861,10 @@ def _rescale_up(values, k, fill_with=None, filling="inner"):

def _interpolate(values, k):
if isinstance(values, list):
# TODO: need fix Matrix.
if _has_numpy:
values = np.array(values)
else:
from ..array import Matrix
from ..array import Matrix

# because true_div ...
values = Matrix(values)
# because true_div ...
values = Matrix(values)
size = len(values)

first_position = 0
Expand Down Expand Up @@ -906,7 +896,7 @@ def rescale_values(
interpolation: bool = False,
fill_with=None,
filling="inner",
) -> Union[List[Any], 'numpy.ndarray']: # noqa: F821
) -> List[Any]:
"""
Resizes a list of values
eg.
Expand All @@ -933,16 +923,15 @@ def rescale_values(
@param filling: in case of scale up, you can define how the filling will be (pre, inner, post). 'inner' is default.
@return: rescaled list of values.
"""
_type = type(values)

if interpolation:
result = _type(list(_interpolate(values, granularity)))
result = list(_interpolate(values, granularity))
else:
if len(values) >= granularity:
result = _type(list(_rescale_down(values, granularity)))
result = list(_rescale_down(values, granularity))
else:
result = _type(
list(_rescale_up(values, granularity, fill_with=fill_with, filling=filling))
result = list(
_rescale_up(values, granularity, fill_with=fill_with, filling=filling)
)

assert (
Expand Down Expand Up @@ -1043,7 +1032,7 @@ def sort_dict(

def list_to_tuple(obj):
assert isinstance(
obj, (list, set, tuple)
obj, (list, set, tuple)
), f"Isn't possible convert {type(obj)} into {tuple}"
result = []
for i in obj:
Expand All @@ -1066,7 +1055,7 @@ def dict_values_len(obj, max_len=None, min_len=None, take_len=False):

def dict_to_tuple(obj):
assert isinstance(
obj, (dict, set)
obj, (dict, set)
), f"Isn't possible convert {type(obj)} into {tuple}"
result = []
if isinstance(obj, set):
Expand Down Expand Up @@ -1198,7 +1187,7 @@ def get_batch_strides(data, kernel_size, strides=1, fill_=True, take_index=False
batches = batches[strides:]
if len(batches):
yield rescale_values(
batches, granularity=kernel_size, filling="post"
batches, granularity=kernel_size, filling="post"
) if fill_ else batches


Expand Down

0 comments on commit 49dc1c0

Please sign in to comment.