Skip to content

Commit

Permalink
Merge pull request #10 from cvxgrp/feature
Browse files Browse the repository at this point in the history
new features
  • Loading branch information
bmeyers authored Mar 16, 2023
2 parents e97fa04 + 6dfba1c commit 7755d54
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 8 deletions.
2 changes: 1 addition & 1 deletion gfosd/components/sums.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(self, *args, **kwargs):
return

def _make_P(self, size):
return self.weight * sp.eye(size)
return self.weight * 2 * sp.eye(size) # note the (1/2) in canonical form!

class SumAbs(GraphComponent):
def __init__(self, *args, **kwargs):
Expand Down
43 changes: 36 additions & 7 deletions gfosd/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import scipy.sparse as sp
import qss
import cvxpy as cvx
from sklearn.model_selection import train_test_split
from osd.masking import Mask

class Problem():
Expand Down Expand Up @@ -82,18 +83,44 @@ def make_graph_form(self):
}
return out

def decompose(self, solver='qss', data=None, make_feasible=True, **kwargs):
if data is None:
data = self.make_graph_form()
def decompose(self, solver='qss', canonical_form=None, make_feasible=True, **kwargs):
if canonical_form is None:
canonical_form = self.make_graph_form()
if solver.lower() == 'qss':
result = self._solve_qss(data, **kwargs)
result = self._solve_qss(canonical_form, **kwargs)

else:
result = self._solve_cvx(data, solver, **kwargs)
result = self._solve_cvx(canonical_form, solver, **kwargs)
self.retrieve_result(result)
if solver.lower() == 'qss' and make_feasible:
self.make_feasible_qss()

def holdout_decompose(self, holdout_fraction=0.1, seed=None,
solver='qss', make_feasible=True, **kwargs):
use_set = self.mask.use_set
size = self.T * self.p
if self.p == 1:
known_ixs = np.arange(size)[use_set]
else:
known_ixs = np.arange(size)[use_set.ravel(order='F')]
train_ixs, test_ixs = train_test_split(
known_ixs, test_size=holdout_fraction, random_state=seed
)
hold_set = np.zeros(size, dtype=bool)
use_set = np.zeros(size, dtype=bool)
hold_set[test_ixs] = True
use_set[train_ixs] = True
if self.p != 1:
hold_set = hold_set.reshape((self.T, self.p), order='F')
use_set = use_set.reshape((self.T, self.p), order='F')
self.__old_mask = self.mask
self.mask = Mask(use_set)
self.decompose(solver=solver, make_feasible=make_feasible, **kwargs)
residual = (self.data[hold_set]
- np.sum(self.decomposition, axis=0)[hold_set])
self.mask = self.__old_mask
return residual, test_ixs

def _solve_qss(self, data, **solver_kwargs):
solver = qss.QSS(data)
objval, soln = solver.solve(**solver_kwargs)
Expand All @@ -107,9 +134,10 @@ def make_feasible_qss(self):
qss_data = self.make_graph_form()
new_solution = np.copy(self._qss_soln)
new_x1 = np.zeros_like(self.decomposition[0])
new_x1[~np.isnan(self.data)] = (
use_set = self.mask.use_set
new_x1[use_set] = (
self.data - np.sum(self.decomposition[1:], axis=0)
)[~np.isnan(self.data)]
)[use_set]
new_solution[:len(new_x1)] = new_x1
self.retrieve_result(new_solution)
self._qss_soln = new_solution
Expand Down Expand Up @@ -177,6 +205,7 @@ def retrieve_result(self, x_value):
else:
self.decomposition = None


def plot_decomposition(self, x_series=None, X_real=None, figsize=(10, 8),
label='estimated', exponentiate=False,
skip=None, **kwargs):
Expand Down

0 comments on commit 7755d54

Please sign in to comment.