Skip to content

Commit

Permalink
improve tests and introduce a class for the center
Browse files Browse the repository at this point in the history
  • Loading branch information
tschm committed Jan 15, 2025
1 parent 72fcf7b commit 1e24aa1
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 11 deletions.
4 changes: 2 additions & 2 deletions src/cvx/ball/solver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import cvxpy as cp
import numpy as np

from .utils.circle import Circle
from .utils.circle import Center, Circle


def min_circle_cvx(points, **kwargs):
Expand All @@ -22,4 +22,4 @@ def min_circle_cvx(points, **kwargs):
problem = cp.Problem(objective=objective, constraints=constraints)
problem.solve(**kwargs)

return Circle(radius=float(r.value), center=x.value)
return Circle(radius=float(r.value), center=Center(x.value))
27 changes: 23 additions & 4 deletions src/cvx/ball/utils/circle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,30 @@
import plotly.graph_objects as go


@dataclass(frozen=True)
class Center:
array: np.ndarray

def __getitem__(self, item):
return self.array[item]

def scatter(self, **kwargs):
return go.Scatter(
x=[self[0]],
y=[self[1]],
mode="markers",
marker=dict(symbol="x", size=8, color="blue"),
name=f"Center(x = {self[0]:.2f}, y = {self[1]:.2f})",
**kwargs,
)


@dataclass(frozen=True)
class Circle:
center: np.ndarray
center: Center
radius: float

def scatter(self, num=100, color="red"):
def scatter(self, num=800, color="red", **kwargs):
t = np.linspace(0, 2 * np.pi, num=num)
radius = self.radius
circle_x = self.center[0] + radius * np.cos(t)
Expand All @@ -19,6 +37,7 @@ def scatter(self, num=100, color="red"):
x=circle_x,
y=circle_y,
mode="lines",
line=dict(color=color, width=2),
name=f"Circle(r = {self.radius})",
line=dict(color=color, width=1),
name=f"Circle(r = {self.radius:.2f})",
**kwargs,
)
5 changes: 3 additions & 2 deletions src/cvx/ball/utils/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
class Cloud:
points: np.ndarray

def scatter(self, size=10):
def scatter(self, size=5, **kwargs):
return go.Scatter(
x=self.points[:, 0],
y=self.points[:, 1],
mode="markers",
marker=dict(symbol="x", size=size, color="blue"),
marker=dict(symbol="circle", size=size, color="black"),
**kwargs,
)
21 changes: 18 additions & 3 deletions src/tests/test_solver.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
import numpy as np
import pytest

from cvx.ball.solver import min_circle_cvx
from cvx.ball.utils.cloud import Cloud
from cvx.ball.utils.figure import create_figure


def test_random():
p = np.array([[2.0, 4.0], [0, 0], [2.5, 2.0]])
cloud = Cloud(p)
p = np.array([[2.0, 4.0], [0.0, 0.0], [2.5, 2.0]])
circle = min_circle_cvx(p, solver="CLARABEL")

fig = create_figure()
fig.add_trace(circle.scatter())
fig.add_trace(cloud.scatter())
fig.add_trace(Cloud(p).scatter())

assert circle.radius == pytest.approx(2.2360679626271796, 1e-6)
assert circle.center.array == pytest.approx([1.0, 2.0], 1e-4)
# fig.show()


def test_graph():
p = np.random.randn(50, 2)
circle = min_circle_cvx(p, solver="CLARABEL")

fig = create_figure()
fig.add_trace(circle.scatter())
fig.add_trace(circle.center.scatter())
fig.add_trace(Cloud(p).scatter(name="Cloud"))

fig.update_layout(xaxis_range=[-3, 3], yaxis_range=[-3, 3])
fig.show()

0 comments on commit 1e24aa1

Please sign in to comment.