Skip to content

Commit

Permalink
Merge pull request #13 from jeertmans/refactor-args
Browse files Browse the repository at this point in the history
chore(lib): change seed to key, and add kwargs
  • Loading branch information
jeertmans authored Jul 25, 2023
2 parents 833cfd1 + 0425941 commit 181eea1
Show file tree
Hide file tree
Showing 7 changed files with 531 additions and 354 deletions.
9 changes: 4 additions & 5 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,11 @@ jobs:
run: |
poetry install --with test
- name: Get files to check
uses: jeertmans/filesfinder@v0.4.5
id: ff
- name: Install FilesFinder
uses: taiki-e/install-action@v2
with:
args: differt2d/**.py README.md
tool: filesfinder@latest

- name: Run ByExample
run: |
echo "${{ steps.ff.outputs.files }}" | xargs poetry run byexample -l python +timeout=60
ff "differt2d/**py" README.md | xargs poetry run byexample -l python +timeout=60
44 changes: 31 additions & 13 deletions differt2d/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, Any, List, Optional

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -458,14 +458,15 @@ class FermatPath(Path):
"""

@classmethod
@partial(jit, static_argnames=("cls", "steps"))
@partial(jit, static_argnames=("cls", "steps", "optimizer"))
def from_tx_objects_rx(
cls,
tx: Point,
objects: List[Interactable],
rx: Point,
key: Optional[jax.random.PRNGKey] = None,
seed: int = 1234,
steps: int = 400,
**kwargs: Any,
) -> "FermatPath":
"""
Returns a path with minimal length.
Expand All @@ -474,8 +475,12 @@ def from_tx_objects_rx(
:param objects:
The list of objects to interact with (order is important).
:param rx: The receiving node.
:param seed: The random seed used to generate the start iteration.
:param steps: The number of iterations performed by the minimizer.
:param key: The random key to generate the initial guess.
:param seed: The random seed used to generate the start iteration,
only used if :python:`key is None`.
:param kwargs:
Keyword arguments to be passed to
:func:`minimize_many_random_uniform<differt2d.optimize.minimize_many_random_uniform>`.
:return: The resulting path of the FPT method.
:Examples:
Expand Down Expand Up @@ -514,8 +519,12 @@ def path_loss(cartesian_coords):

return _loss

key = jax.random.PRNGKey(seed)
theta, _ = minimize_many_random_uniform(fun=loss_fun, n=n_unknowns, key=key)
if key is None:
key = jax.random.PRNGKey(seed)

theta, _ = minimize_many_random_uniform(
fun=loss_fun, n=n_unknowns, key=key, **kwargs
)

points = parametric_to_cartesian(objects, theta, n, tx.point, rx.point)

Expand All @@ -529,14 +538,15 @@ class MinPath(Path):
"""

@classmethod
@partial(jit, static_argnames=("cls", "steps"))
@partial(jit, static_argnames=("cls", "steps", "optimizer"))
def from_tx_objects_rx(
cls,
tx: Point,
objects: List[Interactable],
rx: Point,
key: Optional[jax.random.PRNGKey] = None,
seed: int = 1234,
steps: int = 100,
**kwargs: Any,
) -> "MinPath":
"""
Returns a path that minimizes the sum of interactions.
Expand All @@ -545,8 +555,12 @@ def from_tx_objects_rx(
:param objects:
The list of objects to interact with (order is important).
:param rx: The receiving node.
:param seed: The random seed used to generate the start iteration.
:param steps: The number of iterations performed by the minimizer.
:param key: The random key to generate the initial guess.
:param seed: The random seed used to generate the start iteration,
only used if :python:`key is None`.
:param kwargs:
Keyword arguments to be passed to
:func:`minimize_many_random_uniform<differt2d.optimize.minimize_many_random_uniform>`.
:return: The resulting path of the MPT method.
:Examples:
Expand Down Expand Up @@ -581,8 +595,12 @@ def loss_fun(theta):

return _loss

key = jax.random.PRNGKey(seed)
theta, loss = minimize_many_random_uniform(fun=loss_fun, n=n_unknowns, key=key)
if key is None:
key = jax.random.PRNGKey(seed)

theta, loss = minimize_many_random_uniform(
fun=loss_fun, n=n_unknowns, key=key, **kwargs
)

points = parametric_to_cartesian(objects, theta, n, tx.point, rx.point)

Expand Down
5 changes: 2 additions & 3 deletions differt2d/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def minimize(
f_and_df = jax.value_and_grad(fun)
opt_state = optimizer.init(x0)

@jax.jit
def f(carry, x):
x, opt_state = carry
loss, grads = f_and_df(x)
Expand All @@ -91,7 +90,7 @@ def minimize_random_uniform(
with initial guess drawn randomly from a uniform distribution.
:param fun: The objective function to be minimized.
:param key: The random key to generate the initial guess.
:param key: The random key used to generate the initial guess.
:param n: The size of the random vector to generate.
:param kwargs:
Keyword arguments to be passed to :func:`minimize`.
Expand Down Expand Up @@ -128,7 +127,7 @@ def minimize_many_random_uniform(
and returns the best minimum out of the :code:`many` trials.
:param fun: The objective function to be minimized.
:param key: The random key to generate the initial guess.
:param key: The random key used to generate the initial guesses.
:param n: The size of the random vector to generate.
:param many:
How many times the minimization should be performed.
Expand Down
29 changes: 29 additions & 0 deletions differt2d/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,35 @@ class Scene(Plottable):
The list of objects in the scene.
"""

@classmethod
@partial(jax.jit, static_argnames=("cls", "n"))
def random_uniform_scene(cls, key: jax.random.KeyArray, n: int) -> "Scene":
"""
Generates a random scene with ``n`` walls,
drawing coordinates from a random distribution.
:Examples:
.. plot::
:include-source: true
import matplotlib.pyplot as plt
import jax
from differt2d.scene import Scene
ax = plt.gca()
key = jax.random.PRNGKey(1234)
scene = Scene.random_uniform_scene(key, 5)
_ = scene.plot(ax)
plt.show()
"""
points = jax.random.uniform(key, (2 * n + 2, 2))
tx = Point(point=points[+0, :])
rx = Point(point=points[-1, :])

walls = [Wall(points=points[2 * i : 2 * i + 2, :]) for i in range(1, n + 1)]
return cls(tx=tx, rx=rx, objects=walls)

@classmethod
def basic_scene(cls) -> "Scene":
"""
Expand Down
60 changes: 60 additions & 0 deletions examples/mpt_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from itertools import product

import jax
import optax
import pandas as pd

from differt2d.geometry import MinPath
from differt2d.scene import Scene


def min_path_tracing_loss(
key: jax.random.PRNGKey, size: int, optimizer: optax.GradientTransformation
):
key1, key2 = jax.random.split(key, 2)
scene = Scene.random_uniform_scene(key1, size)
return MinPath.from_tx_objects_rx(
scene.tx, scene.objects, scene.rx, key2, optimizer=optimizer
).loss


def main():
n = 1000
seed = 1234
key = jax.random.PRNGKey(seed)

sizes = [1, 2, 3, 4, 5]
optimizers = {
"adam": optax.adam,
"sgd": optax.sgd,
"adagrad": optax.adagrad,
"noisy_sgd": optax.noisy_sgd,
}
learning_rates = [1e-3, 1e-2, 1e-1, 1e-0]

parameters = product(sizes, optimizers.keys(), learning_rates)

results = {}

for size, optimizer, learning_rate in parameters:
print("size:", size)
key, key_to_use = jax.random.split(key)

opt = optimizers[optimizer](learning_rate)
losses = jax.vmap(min_path_tracing_loss, in_axes=(0, None, None), out_axes=0)(
jax.random.split(key_to_use, n), size, opt
)
results[(size, optimizer, learning_rate)] = losses

index = pd.MultiIndex.from_tuples(
results.keys(), names=["size", "optimizer", "learning_rate"]
)
df = pd.DataFrame(
data=results.values(),
index=index,
)
df.to_csv("optimizers.csv")


if __name__ == "__main__":
main()
Loading

0 comments on commit 181eea1

Please sign in to comment.