Skip to content

Commit

Permalink
Merge pull request #17 from cvxgrp/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
bmeyers authored Sep 26, 2023
2 parents 703a039 + dbc8a2f commit 37e4f01
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 5 deletions.
18 changes: 14 additions & 4 deletions gfosd/components/basis_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@
class Basis(GraphComponent):
def __init__(self, basis, penalty=None, *args, **kwargs):
self._basis = basis
# penalty can be None, an atom name (e.g. 'sum_square' or 'abs'), or PSD matrix (2D numpy array)
self._penalty = penalty
if isinstance(penalty, np.ndarray) or isinstance(penalty, sp.spmatrix):
self._penalty = 'matrix'
self._pmat = penalty
else:
self._ndim = None
super().__init__(*args, **kwargs)
self._has_helpers = True

Expand All @@ -40,19 +46,23 @@ def _make_B(self):
self._B = self._basis * -1

def _make_g(self, size):
if (self._penalty is None) or (self._penalty == 'sum_square'):
if (self._penalty is None) or (self._penalty == 'sum_square') or (self._penalty == 'matrix'):
g = []
else:
# typically 'abs', 'huber', or 'quantile'
g = [{'g': self._penalty,
'args': {'weight': self.weight},
'range': (0, size)}]
return g

def _make_P(self, size):
if (self._penalty is None) or (self._penalty != 'sum_square'):
P = sp.dok_matrix(2 * (size,))
else:
if self._penalty == 'matrix':
P = sp.dia_matrix(self._pmat)
P = P.power(2)
elif np.all(self._penalty == 'sum_square'):
P = self.weight * sp.eye(size)
else:
P = sp.dok_matrix(2 * (size,))
return P

class Periodic(Basis):
Expand Down
10 changes: 9 additions & 1 deletion gfosd/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,15 @@ def _solve_cvx(self, data, solver, **solver_kwargs):
return
objective = cvx.Minimize(cost)
cvx_prob = cvx.Problem(objective, constraints)

if solver == "CLARABEL":
if "eps_rel" in solver_kwargs.keys():
er = solver_kwargs["eps_rel"]
del solver_kwargs["eps_rel"]
solver_kwargs["tol_gap_rel"] = er
if "eps_abs" in solver_kwargs.keys():
ea = solver_kwargs["eps_abs"]
del solver_kwargs["eps_abs"]
solver_kwargs["tol_gap_abs"] = ea
cvx_prob.solve(solver=solver, **solver_kwargs)
self._cvx_obj = cvx_prob
self.objective_value = cvx_prob.value
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ cvxpy
matplotlib
scikit-learn
qss
clarabel

0 comments on commit 37e4f01

Please sign in to comment.