Skip to content

Commit

Permalink
improves: improve split_sequence and add map_values utils
Browse files Browse the repository at this point in the history
  • Loading branch information
jlsneto committed Sep 12, 2024
1 parent e4766bb commit 16136eb
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 25 deletions.
2 changes: 1 addition & 1 deletion cereja/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from ._requests import request
from . import scraping

VERSION = "2.0.1.final.0"
VERSION = "2.0.2.final.0"

__version__ = get_version_pep440_compliant(VERSION)

Expand Down
2 changes: 1 addition & 1 deletion cereja/config/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,4 @@ def get(self, item=None):
]

console_logger = logging.StreamHandler(sys.stdout)
logging.basicConfig(handlers=(console_logger,), level=logging.WARNING)
logging.basicConfig(handlers=(console_logger,), level=logging.INFO)
100 changes: 79 additions & 21 deletions cereja/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import sys
import types
import random
from typing import Any, Union, List, Tuple, Sequence, Iterable, Dict, MappingView, Optional, Callable, AnyStr
from typing import Any, Union, List, Tuple, Sequence, Iterable, Dict, MappingView, Optional, Callable, AnyStr, Iterator
import logging
import itertools
from copy import copy
Expand Down Expand Up @@ -94,7 +94,10 @@
"value_from_memory",
"str_gen",
"set_interval",
"SourceCodeAnalyzer"
"SourceCodeAnalyzer",
"map_values",
'decode_coordinates',
'encode_coordinates',
]

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -140,31 +143,51 @@ def split_sequence(seq: List[Any], is_break_fn: Callable) -> List[List[Any]]:
@return: list of subsequences
"""
if not isinstance(seq, list) or not seq:
raise ValueError("A sequência deve ser uma lista não vazia.")
raise ValueError("The sequence must be a non-empty list.")

if not callable(is_break_fn):
raise TypeError("is_break_fn deve ser uma função.")
raise TypeError("is_break_fn must be a function.")

# Inicializa com a primeira subsequência
sub_seqs = []
start_idx = 0
break_fn_arg_count = SourceCodeAnalyzer(is_break_fn).argument_count
for indx, val in enumerate(seq):
if len(seq) == 1:
return [seq]
for idx, is_break in enumerate(map_values(seq, is_break_fn)):
if is_break:
sub_seqs.append(seq[start_idx:idx + 1])
start_idx = idx + 1
if start_idx < len(seq):
sub_seqs.append(seq[start_idx:])
return sub_seqs


def map_values(obj: Union[dict, list, tuple, Iterator], fn: Callable) -> Union[dict, list, tuple, Iterator]:
fn_arg_count = SourceCodeAnalyzer(fn).argument_count
is_dict = isinstance(obj, dict)
if isinstance(obj, dict):
obj = obj.items()
_iter = iter(obj)
last = next(_iter, '__stop__')
if last == '__stop__':
return map(fn, obj)
idx = 0
while last != '__stop__':
_args = None
if indx + 1 == len(seq):
sub_seqs.append(seq[start_idx:])
_next = next(_iter, '__stop__')

if fn_arg_count == 1:
_args = (last,)
elif fn_arg_count == 2:
_args = (last, None if _next == '__stop__' else _next)
elif fn_arg_count == 3:
_args = (last, None if _next == '__stop__' else _next, idx)
if _next == '__stop__' and fn_arg_count >= 2:
if idx == 0:
yield fn(*_args)
break
if break_fn_arg_count == 1:
_args = (val,)
elif break_fn_arg_count == 2:
_args = (val, seq[indx + 1])
elif break_fn_arg_count == 3:
_args = (val, seq[indx + 1], indx)

if is_break_fn(*_args) if _args else is_break_fn():
sub_seqs.append(seq[start_idx:indx+1])
start_idx = indx+1
return sub_seqs
yield fn(*_args) if _args else last
last = _next
idx += 1


def chunk(
Expand Down Expand Up @@ -1518,7 +1541,7 @@ def get_zero_mask(number: int, max_len: int = 3) -> str:
return f"%0.{max_len}d" % number


def get_batch_strides(data, kernel_size, strides=1, fill_=True, take_index=False):
def get_batch_strides(data, kernel_size, strides=1, fill_=False, take_index=False):
"""
Returns batches of fixed window size (kernel_size) with a given stride
@param data: iterable
Expand Down Expand Up @@ -1565,3 +1588,38 @@ def set_interval(func: Callable, sec: float):
"""
from .decorators import on_elapsed
on_elapsed(sec, loop=True, use_threading=True)(func)()


def encode_coordinates(x: int, y: int):
"""
Encode the coordinates (x, y) into a single lParam value.
The encoding is done by shifting the y-coordinate 16 bits to the left and
then performing a bitwise OR with the x-coordinate.
Args:
x (int): The x-coordinate.
y (int): The y-coordinate.
Returns:
int: The encoded lParam value.
"""
return (y << 16) | x


def decode_coordinates(lparam: int):
"""
Decode the lParam value back into the original coordinates (x, y).
The decoding is done by extracting the lower 16 bits for the x-coordinate
and the upper 16 bits for the y-coordinate.
Args:
lparam (int): The encoded lParam value.
Returns:
tuple: A tuple containing the x and y coordinates.
"""
x = lparam & 0xFFFF
y = (lparam >> 16) & 0xFFFF
return x, y
6 changes: 4 additions & 2 deletions tests/testsutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def test_get_batch_strides(self):
]
for test_value, kernel_size, strides, expected in tests:
self.assertEqual(
list(utils.get_batch_strides(test_value, kernel_size, strides)),
list(utils.get_batch_strides(test_value, kernel_size, strides, fill_=True)),
expected,
)

Expand Down Expand Up @@ -623,7 +623,9 @@ def __init__(self, x, y):
self.y = y

def __eq__(self, other):
return self.x == other.x and self.y == other.y
if isinstance(other, Point):
return self.x == other.x and self.y == other.y
return False

seq = [Point(1, 2), Point(1, 3), Point(4, 4), Point(4, 8)]
is_same_x = lambda p1, p2: p1.x == p2.x
Expand Down

0 comments on commit 16136eb

Please sign in to comment.