Skip to content

Commit

Permalink
chore: add Rotation utils
Browse files Browse the repository at this point in the history
  • Loading branch information
jlsneto committed Nov 24, 2023
1 parent 73d847c commit 5ccc1bd
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 71 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ cj.array.rand_n(0.0, 2.0, n=3) # [0.3001196087729699, 0.639679494102923, 1.0602
cj.array.rand_n(1, 10) # 5.086403830031244
cj.array.array_randn((3, 3,
3)) # [[[0.015077210355770374, 0.014298110484612511, 0.030410666810216064], [0.029319083335697604, 0.0072365209507707666, 0.010677361074992], [0.010576754075922935, 0.04146379877648334, 0.02188348813336284]], [[0.0451851551098092, 0.037074906805326824, 0.0032484586475421007], [0.025633380630695347, 0.010312669541918484, 0.0373624007621097], [0.047923908102496145, 0.0027939333359724224, 0.05976224377251878]], [[0.046869510719106486, 0.008325638358172866, 0.0038702998343255893], [0.06475268683502387, 0.0035638592537234623, 0.06551037943638163], [0.043317416824708604, 0.06579372884523939, 0.2477564291871006]]]
cj.array.group_items_in_batches(items=[1, 2, 3, 4], items_per_batch=3, fill=0) # [[1, 2, 3], [4, 0, 0]]
cj.chunk(data=[1, 2, 3, 4], batch_size=3, fill_with=0) # [[1, 2, 3], [4, 0, 0]]
cj.array.remove_duplicate_items(['hi', 'hi', 'ih']) # ['hi', 'ih']
cj.array.get_cols([['line1_col1', 'line1_col2'],
['line2_col1', 'line2_col2']]) # [['line1_col1', 'line2_col1'], ['line1_col2', 'line2_col2']]
Expand Down
89 changes: 29 additions & 60 deletions cereja/array/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
"flatten",
"get_cols",
"get_shape",
"get_shape_recursive",
"group_items_in_batches",
"is_empty",
"rand_n",
"rand_uniform",
Expand All @@ -55,18 +53,19 @@
]

from ..utils import is_iterable, is_sequence, is_numeric_sequence, chunk, dict_to_tuple
from ..utils.decorators import depreciation, time_exec
from ..utils.decorators import time_exec

logger = logging.getLogger(__name__)


def shape_is_ok(
sequence: Union[Sequence[Any], Any], expected_shape: Tuple[int, ...]
sequence: Union[Sequence[Any], Any]
) -> bool:
"""
Check the number of items the array has and compare it with the shape product
"""
try:
expected_shape = get_shape(sequence)
sequence_len = len(flatten(sequence))
except Exception as err:
logger.debug(f"Error when trying to compare shapes. {err}")
Expand All @@ -83,42 +82,28 @@ def is_empty(sequence: Sequence) -> bool:
return False


def get_shape(sequence: Sequence[Any]) -> Tuple[Union[int, None], ...]:
def get_shape(sequence: Sequence) -> Tuple[Union[int, None], ...]:
"""
Responsible for analyzing the depth of a sequence
Get the shape (dimensions and sizes) of a nested sequence (like a list or tuple).
:param sequence: Is sequence of values.
:return: number of dimensions
"""
if is_empty(sequence):
return (None,)
wkij = []
while True:
if is_sequence(sequence) and not is_empty(sequence):
wkij.append(len(sequence))
sequence = sequence[0]
continue
break
return tuple(wkij)
If the sequence is empty or not uniform (e.g., sub-sequences have different lengths),
returns None for the dimension(s) where uniformity breaks.
Parameters:
sequence (Sequence): A sequence of values, possibly nested.
def get_shape_recursive(
sequence: Sequence[Any], wki: Tuple[int, ...] = None
) -> Tuple[int, ...]:
Returns:
Tuple[Union[int, None], ...]: A tuple representing the size of each dimension of the sequence.
"""
[!] Never use recursion in python if it is possible to exceed 997 calls [!]
if not sequence if not type(sequence).__name__ == "ndarray" else len(sequence) == 0:
return (None,)

[!] Function for teaching purposes [!]
shape = []
while is_sequence(sequence) and not is_empty(sequence):
shape.append(len(sequence))
sequence = sequence[0] if len(sequence) else None

:param sequence: Is sequence of values.
:param wki: stores value for each dimension
"""
if wki is None:
wki = []
if is_sequence(sequence):
wki += [len(sequence)]
return get_shape_recursive(sequence[0], wki)
return tuple(wki)
return tuple(shape)


def reshape(sequence: Sequence, shape):
Expand Down Expand Up @@ -162,7 +147,7 @@ def array_gen(

is_seq = is_sequence(v)

allow_reshape = shape_is_ok(v, shape) and is_seq
allow_reshape = shape_is_ok(v) and is_seq

if not is_seq:
v = [v if v else 0.0]
Expand Down Expand Up @@ -401,29 +386,6 @@ def array_randn(shape: Tuple[int, ...], *args, **kwargs) -> List[Union[float, An
return array_gen(shape=shape, v=rand_n_values)


@depreciation(alternative="cereja.utils.chunk")
def group_items_in_batches(
items: List[Any], items_per_batch: int = 0, fill: Any = None
) -> List[List[Any]]:
"""
Responsible for grouping items in batch taking into account the quantity of items per batch
e.g.
>>> group_items_in_batches(items=[1,2,3,4], items_per_batch=3)
[[1, 2, 3], [4]]
>>> group_items_in_batches(items=[1,2,3,4], items_per_batch=3, fill=0)
[[1, 2, 3], [4, 0, 0]]
:param items: list of any values
:param items_per_batch: number of items per batch
:param fill: fill examples when items is not divisible by items_per_batch, default is None
:return:
"""
from cereja.utils import chunk

return chunk(data=items, batch_size=items_per_batch, fill_with=fill)


def remove_duplicate_items(items: Optional[list]) -> Any:
"""
remove duplicate items in an item list or duplicate items list of list
Expand Down Expand Up @@ -455,22 +417,29 @@ def get_cols(sequence: Union[Sequence, "Matrix"]):
return list(zip(*sequence))


def prod(sequence: Sequence[Number]) -> Number:
def prod(sequence: Sequence[Number], start=1) -> Number:
"""
Calculates the product of the values.
Calculate the product of all the elements in the input iterable.
The default start value for the product is 1.
This function is intended specifically for use with numeric values and may
reject non-numeric types.
:param sequence: Is a sequence of numbers.
:param start: is a number
:return:
"""
if hasattr(math, "prod"):
# New in Python 3.8
return math.prod(sequence, start=start)
# alternative for Python < 3.8
if not is_sequence(sequence):
raise TypeError(
f"Value of {sequence} is not valid. Please send a numeric list."
)

return reduce((lambda x, y: x * y), sequence)
return reduce((lambda x, y: x * y), [start, *sequence])


def sub(sequence: Sequence[Number]) -> Number:
Expand Down
4 changes: 3 additions & 1 deletion cereja/config/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@
from abc import abstractmethod, ABCMeta
import os

__all__ = ["BasicConfig", "BASE_DIR"]
__all__ = ["BasicConfig", "BASE_DIR", "PYTHON_VERSION"]

# using by utils.module_references
_exclude = ["console_logger", "cj_modules_dotted_path"]

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

PYTHON_VERSION = sys.version_info[:3]


class BasicConfig(metaclass=ABCMeta):
def __init__(self, hook=None, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions cereja/geolinear/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .point import Point
from .utils import rotation_matrix_3d, rotation_matrix_2d
72 changes: 72 additions & 0 deletions cereja/geolinear/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import math

__all__ = ["Rotation"]

from typing import Union


class Rotation:
__axis_map = {0: "x", 1: "y", 2: "z"}

def __init__(self, degrees: Union[int, float]):
self._angle_rad = math.radians(degrees)
self._cos_angle = math.cos(self._angle_rad)
self._sin_angle = math.sin(self._angle_rad)
self._rot_x = [[1, 0, 0], [0, self._cos_angle, -self._sin_angle], [0, self._sin_angle, self._cos_angle]]
self._rot_y = [[self._cos_angle, 0, self._sin_angle], [0, 1, 0], [-self._sin_angle, 0, self._cos_angle]]
self._rot_z = [[self._cos_angle, -self._sin_angle, 0], [self._sin_angle, self._cos_angle, 0], [0, 0, 1]]
self._rot_2d = [[self._cos_angle, -self._sin_angle], [self._sin_angle, self._cos_angle]]

@property
def rot_3d_x(self):
return self._rot_x

@property
def rot_3d_y(self):
return self._rot_y

@property
def rot_3d_z(self):
return self._rot_z

@property
def rot_2d(self):
return self._rot_2d

def rotate_point(self, point, axis=None):
if len(point) == 2:
rx = self._rot_2d[0][0] * point[0] + self._rot_2d[0][1] * point[1]
ry = self._rot_2d[1][0] * point[0] + self._rot_2d[1][1] * point[1]
return [rx, ry]
assert len(point) == 3, ValueError(f"{point} isn't 3D point.")
if axis is None:
axis = 0
else:
assert (axis.lower() if isinstance(axis, str) else axis) in ("x", "y", "z", 0, 1, 2), ValueError(
"Invalid axis: choose 'x', 'y', 'z' or 0, 1, 2")
axis = axis if isinstance(axis, str) else self.__axis_map[int(axis)]
if axis == "x":
rot = self.rot_3d_x
elif axis == "y":
rot = self.rot_3d_y
else:
rot = self.rot_3d_z

rx = (rot[0][0] * point[0]) + (rot[0][1] * point[1]) + (rot[0][2] * point[2])
ry = (rot[1][0] * point[0]) + (rot[1][1] * point[1]) + (rot[1][2] * point[2])
rz = (rot[2][0] * point[0]) + (rot[2][1] * point[1]) + (rot[2][2] * point[2])

return [rx, ry, rz]

def rotate(self, val, axis=None):
from cereja import get_shape, reshape, flatten, shape_is_ok

if not shape_is_ok(val):
raise ValueError("value isn't valid to rotate.")

shape = get_shape(val)
if len(shape) == 1:
return self.rotate_point(val, axis=axis)

return reshape(list(map(lambda point: self.rotate_point(point, axis=axis), flatten(val, depth=len(shape) - 2))),
shape)
6 changes: 4 additions & 2 deletions cereja/system/_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from datetime import datetime
from typing import List, Union

from ..array import group_items_in_batches

from pathlib import Path as Path_
import glob
from ..utils.decorators import on_except
Expand Down Expand Up @@ -96,6 +96,7 @@ def group_path_from_dir(
:param key_sort_function: function order items
:return:
"""
from ..utils import chunk

if "." not in ext_file:
ext_file = "." + ext_file
Expand All @@ -110,7 +111,8 @@ def group_path_from_dir(
**key_sort_function,
)

batches = group_items_in_batches(items=paths, items_per_batch=num_items_on_tuple)

batches = chunk(paths, batch_size=num_items_on_tuple)

if to_random:
random.shuffle(batches)
Expand Down
15 changes: 9 additions & 6 deletions cereja/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,13 +1172,16 @@ def is_iterable(obj: Any) -> bool:
"""
Return whether an object is iterable or not.
:param obj: Any object for check
This function checks if the object has an __iter__ method or supports
sequence-like indexing via __getitem__.
Parameters:
obj (Any): Any object to check for iterability.
Returns:
bool: True if the object is iterable, False otherwise.
"""
try:
iter(obj)
except TypeError:
return False
return True
return hasattr(obj, '__iter__') or hasattr(obj, '__getitem__')


def has_length(seq):
Expand Down
1 change: 0 additions & 1 deletion tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import logging

from cereja.array import (
group_items_in_batches,
remove_duplicate_items,
flatten,
array_gen,
Expand Down

0 comments on commit 5ccc1bd

Please sign in to comment.