diff --git a/src/cvx/ball/solver.py b/src/cvx/ball/solver.py index 06f7f27..0297486 100644 --- a/src/cvx/ball/solver.py +++ b/src/cvx/ball/solver.py @@ -1,7 +1,7 @@ import cvxpy as cp import numpy as np -from .utils.circle import Circle +from .utils.circle import Center, Circle def min_circle_cvx(points, **kwargs): @@ -22,4 +22,4 @@ def min_circle_cvx(points, **kwargs): problem = cp.Problem(objective=objective, constraints=constraints) problem.solve(**kwargs) - return Circle(radius=float(r.value), center=x.value) + return Circle(radius=float(r.value), center=Center(x.value)) diff --git a/src/cvx/ball/utils/circle.py b/src/cvx/ball/utils/circle.py index 6c36373..d4df413 100644 --- a/src/cvx/ball/utils/circle.py +++ b/src/cvx/ball/utils/circle.py @@ -4,12 +4,30 @@ import plotly.graph_objects as go +@dataclass(frozen=True) +class Center: + array: np.ndarray + + def __getitem__(self, item): + return self.array[item] + + def scatter(self, **kwargs): + return go.Scatter( + x=[self[0]], + y=[self[1]], + mode="markers", + marker=dict(symbol="x", size=8, color="blue"), + name=f"Center(x = {self[0]:.2f}, y = {self[1]:.2f})", + **kwargs, + ) + + @dataclass(frozen=True) class Circle: - center: np.ndarray + center: Center radius: float - def scatter(self, num=100, color="red"): + def scatter(self, num=800, color="red", **kwargs): t = np.linspace(0, 2 * np.pi, num=num) radius = self.radius circle_x = self.center[0] + radius * np.cos(t) @@ -19,6 +37,7 @@ def scatter(self, num=100, color="red"): x=circle_x, y=circle_y, mode="lines", - line=dict(color=color, width=2), - name=f"Circle(r = {self.radius})", + line=dict(color=color, width=1), + name=f"Circle(r = {self.radius:.2f})", + **kwargs, ) diff --git a/src/cvx/ball/utils/cloud.py b/src/cvx/ball/utils/cloud.py index 110b89e..8869bf2 100644 --- a/src/cvx/ball/utils/cloud.py +++ b/src/cvx/ball/utils/cloud.py @@ -8,10 +8,11 @@ class Cloud: points: np.ndarray - def scatter(self, size=10): + def scatter(self, size=5, **kwargs): return go.Scatter( x=self.points[:, 0], y=self.points[:, 1], mode="markers", - marker=dict(symbol="x", size=size, color="blue"), + marker=dict(symbol="circle", size=size, color="black"), + **kwargs, ) diff --git a/src/tests/test_solver.py b/src/tests/test_solver.py index 63b1785..f29112b 100644 --- a/src/tests/test_solver.py +++ b/src/tests/test_solver.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from cvx.ball.solver import min_circle_cvx from cvx.ball.utils.cloud import Cloud @@ -6,12 +7,26 @@ def test_random(): - p = np.array([[2.0, 4.0], [0, 0], [2.5, 2.0]]) - cloud = Cloud(p) + p = np.array([[2.0, 4.0], [0.0, 0.0], [2.5, 2.0]]) circle = min_circle_cvx(p, solver="CLARABEL") fig = create_figure() fig.add_trace(circle.scatter()) - fig.add_trace(cloud.scatter()) + fig.add_trace(Cloud(p).scatter()) + assert circle.radius == pytest.approx(2.2360679626271796, 1e-6) + assert circle.center.array == pytest.approx([1.0, 2.0], 1e-4) # fig.show() + + +def test_graph(): + p = np.random.randn(50, 2) + circle = min_circle_cvx(p, solver="CLARABEL") + + fig = create_figure() + fig.add_trace(circle.scatter()) + fig.add_trace(circle.center.scatter()) + fig.add_trace(Cloud(p).scatter(name="Cloud")) + + fig.update_layout(xaxis_range=[-3, 3], yaxis_range=[-3, 3]) + fig.show()