Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Priors/intermediate prior class #21

Draft
wants to merge 32 commits into
base: RDP/HessianVectorProduct
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
62a5521
RDP/accumulate_Hessian_times_input
Jun 7, 2021
eb07790
Fix issue with logic
robbietuk Jun 7, 2021
b7007b4
Make prior test method `test_gradient`
robbietuk Jun 8, 2021
96b6ec0
Adds test_Hessian method
robbietuk Jun 8, 2021
2f418bc
Test an array of different cases for the Hessian condition
robbietuk Jun 9, 2021
fa65422
improve Hessian documentation and info
robbietuk Jun 9, 2021
471019b
Merge branch 'RDP/TestHessian' into RDP/HessianVectorProduct
robbietuk Jun 9, 2021
3e4e5f4
Improve dot product
robbietuk Jun 11, 2021
a47fe11
Merge branch 'RDP/TestHessian' into Priors/Hessian
robbietuk Jun 11, 2021
101c4fc
Adds _is_convex variable to priors. default is false
robbietuk Jun 11, 2021
4bcafe5
Merge branch 'Priors/is_convex' into Priors/Hessian
robbietuk Jun 11, 2021
8187715
Reimplementation add_multiplication_with_approximate_Hessian
robbietuk Jun 11, 2021
410943e
Correct math in QP accumulate_Hessian_times_input
robbietuk Jun 11, 2021
948397f
Improve QP comments
robbietuk Jun 11, 2021
2c41bc7
Merge branch 'Priors/HessianQP' into Priors/Hessian
robbietuk Jun 15, 2021
a339ac0
[ci skip] Add two methods to RDP class :(off_)diagonal_second_derivative
robbietuk Jun 15, 2021
b6e6b54
Prior Hessian methods use second partial derivative functions (#14)
Jun 15, 2021
b3c8d6f
compute_Hessian method to all priors and test improvements in general
Jun 19, 2021
d767044
[ci skip] Add documentation to `test_gradient`
robbietuk Jun 19, 2021
c2ccf06
Increase RDP epsilon to be 0.1, increase so well above test `eps` value
robbietuk Jun 24, 2021
f60d97c
Add variables and set methods for toggling tests
robbietuk Jun 24, 2021
6669186
Cleanup and improve debugging checks
robbietuk Jun 24, 2021
511c69b
Remove _is_convex variable and better implement is_convex() method
robbietuk Jul 19, 2021
381b6c3
Update src/include/stir/recon_buildblock/GeneralisedPrior.h
Jul 19, 2021
610c784
Update src/include/stir/recon_buildblock/GeneralisedPrior.h
Jul 19, 2021
12f3eb6
Update src/include/stir/recon_buildblock/LogcoshPrior.h
Jul 19, 2021
1729522
Merge remote-tracking branch 'robbietuk/RDP/HessianVectorProduct' int…
robbietuk Jul 19, 2021
c68f4d9
[ci skip] Remove unnecessary documentation
robbietuk Jul 19, 2021
59aed0b
[ci skip] Rework and rename of second derivative methods for priors
robbietuk Jul 19, 2021
15aebe7
All prior Hessian methods return void values
robbietuk Jul 20, 2021
5fc8fea
Initial commit adding a new parent class
robbietuk Sep 30, 2021
30b3ee6
Buildable version of QP using GeneralisedConvexPrior as an intermediate
robbietuk Oct 2, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/include/stir/recon_buildblock/FilterRootPrior.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class FilterRootPrior: public
FilterRootPrior(shared_ptr<DataProcessor<DataT> >const&,
const float penalization_factor);

bool is_convex() const;

//! compute the value of the function
/*! \warning Generally there is no function associated to this prior,
so we just return 0 and write a warning the first time it's called.
Expand Down
73 changes: 73 additions & 0 deletions src/include/stir/recon_buildblock/GeneralisedConvexPrior.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
//
//
/*
Copyright (C) 2000- 2007, Hammersmith Imanet Ltd
This file is part of STIR.

SPDX-License-Identifier: Apache-2.0

See STIR/LICENSE.txt for details
*/
/*!
\file
\ingroup priors
\brief Declaration of class stir::GeneralisedConvexPrior

\author Robert Twyman

*/

#ifndef __stir_recon_buildblock_GeneralisedConvexPrior_H__
#define __stir_recon_buildblock_GeneralisedConvexPrior_H__


#include "stir/recon_buildblock/GeneralisedPrior.h"

START_NAMESPACE_STIR

class Succeeded;

/*!
\ingroup priors
\brief
Make a brief
*/
template <typename DataT>

class GeneralisedConvexPrior:
virtual public GeneralisedPrior<DataT>

{
private:
typedef GeneralisedPrior<DataT> base_type;

public:
//! This computes a single row of the Hessian
/*! The method computes a row (i.e. at a densel/voxel, indicated by \c coords) of the Hessian at \c current_estimate.
Note that a row corresponds to an object of `DataT`.
The method (as implemented in derived classes) should store the result in \c prior_Hessian_for_single_densel.
*/
virtual void
compute_Hessian(DataT& prior_Hessian_for_single_densel,
const BasicCoordinate<3,int>& coords,
const DataT& current_image_estimate) const = 0;

// float penalisation_factor;




// virtual void
// actual_compute_Hessian(DataT& prior_Hessian_for_single_densel,
// const BasicCoordinate<3,int>& coords,
// const DataT& current_image_estimate) const;

int my_value = 1;

};

END_NAMESPACE_STIR

//#include "stir/recon_buildblock/GeneralisedPrior.inl"

#endif
21 changes: 19 additions & 2 deletions src/include/stir/recon_buildblock/GeneralisedPrior.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,27 @@ class GeneralisedPrior:
virtual void compute_gradient(DataT& prior_gradient,
const DataT &current_estimate) =0;

//! This computes a single row of the Hessian
/*! Default implementation just call error(). This function needs to be overridden by the
derived class.

The method computes a row (i.e. at a densel/voxel, indicated by \c coords) of the Hessian at \c current_estimate.
Note that a row corresponds to an object of `DataT`.
The method (as implemented in derived classes) should store the result in \c prior_Hessian_for_single_densel.
*/
virtual void
compute_Hessian(DataT& prior_Hessian_for_single_densel,
const BasicCoordinate<3,int>& coords,
const DataT& current_image_estimate) const;

//! This should compute the multiplication of the Hessian with a vector and add it to \a output
/*! Default implementation just call error(). This function needs to be overridden by the
derived class.
This method assumes that the hessian of the prior is 1 and hence the function quadratic.
Instead, accumulate_Hessian_times_input() should be used. This method remains for backwards comparability.
\warning The derived class should accumulate in \a output.
*/
virtual Succeeded
virtual void
add_multiplication_with_approximate_Hessian(DataT& output,
const DataT& input) const;

Expand All @@ -76,7 +89,7 @@ class GeneralisedPrior:
derived class.
\warning The derived class should accumulate in \a output.
*/
virtual Succeeded
virtual void
accumulate_Hessian_times_input(DataT& output,
const DataT& current_estimate,
const DataT& input) const;
Expand All @@ -89,6 +102,10 @@ class GeneralisedPrior:
virtual Succeeded
set_up(shared_ptr<const DataT> const& target_sptr);

//! Indicates if the prior is a smooth convex function
/*! If true, the prior is expected to have 0th, 1st and 2nd order behaviour implemented.*/
virtual bool is_convex() const = 0;

protected:
float penalisation_factor;
//! sets value for penalisation factor
Expand Down
36 changes: 19 additions & 17 deletions src/include/stir/recon_buildblock/LogcoshPrior.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ class LogcoshPrior: public
parabolic_surrogate_curvature_depends_on_argument() const
{ return false; }

bool is_convex() const;

//! compute the value of the function
double
compute_value(const DiscretisedDensity<3,elemT> &current_image_estimate);
Expand All @@ -124,13 +126,13 @@ class LogcoshPrior: public
void parabolic_surrogate_curvature(DiscretisedDensity<3,elemT>& parabolic_surrogate_curvature,
const DiscretisedDensity<3,elemT> &current_image_estimate);

//! compute Hessian
void compute_Hessian(DiscretisedDensity<3,elemT>& prior_Hessian_for_single_densel,
const BasicCoordinate<3,int>& coords,
const DiscretisedDensity<3,elemT> &current_image_estimate);
virtual void
compute_Hessian(DiscretisedDensity<3,elemT>& prior_Hessian_for_single_densel,
const BasicCoordinate<3,int>& coords,
const DiscretisedDensity<3,elemT> &current_image_estimate) const;

//! Compute the multiplication of the hessian of the prior multiplied by the input.
virtual Succeeded accumulate_Hessian_times_input(DiscretisedDensity<3,elemT>& output,
//! Compute the multiplication of the hessian of the prior (at \c current_estimate) and the \c input.
virtual void accumulate_Hessian_times_input(DiscretisedDensity<3,elemT>& output,
const DiscretisedDensity<3,elemT>& current_estimate,
const DiscretisedDensity<3,elemT>& input) const;

Expand Down Expand Up @@ -224,21 +226,21 @@ class LogcoshPrior: public
{ return tanh(x)/x; }
}

//! The Hessian of log(cosh()) is sech^2(x) = (1/cosh(x))^2
//! The second partial derivatives of the LogCosh Prior
/*!
This function returns the hessian of the logcosh function
* @param d the difference between the ith and jth voxel.
* @param scalar is the logcosh scalar value controlling the priors transition between the quadratic and linear behaviour
* @return the second derivative of the log-cosh function
derivative_20 refers to the second derivative w.r.t. x_j only (i.e. diagonal elements of the Hessian)
derivative_11 refers to the second derivative w.r.t. x_j and x_k (i.e. off-diagonal elements of the Hessian)
* @param x_j is the target voxel.
* @param x_k is the voxel in the neighbourhood.
* @return the second order partial derivatives of the LogCosh Prior
*/
static inline float Hessian(const float d, const float scalar)
{
const float x = d * scalar;
return square((1/ cosh(x)));
}
//@{
elemT derivative_20(const elemT x_j, const elemT x_k) const;
elemT derivative_11(const elemT x_j, const elemT x_k) const;
//@}
};


END_NAMESPACE_STIR

#endif
#endif
2 changes: 2 additions & 0 deletions src/include/stir/recon_buildblock/PLSPrior.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ class PLSPrior: public
/*! \todo set the anatomical image to zero if not defined */
virtual Succeeded set_up(shared_ptr<const DiscretisedDensity<3,elemT> > const& target_sptr);

bool is_convex() const;

//! compute the value of the function
double
compute_value(const DiscretisedDensity<3,elemT> &current_image_estimate);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ START_NAMESPACE_STIR
*/
template <typename TargetT>
class PriorWithParabolicSurrogate:
public GeneralisedPrior<TargetT>
virtual public GeneralisedPrior<TargetT>
{
public:

Expand Down
39 changes: 28 additions & 11 deletions src/include/stir/recon_buildblock/QuadraticPrior.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include "stir/RegisteredParsingObject.h"
#include "stir/recon_buildblock/PriorWithParabolicSurrogate.h"
#include "stir/recon_buildblock/GeneralisedConvexPrior.h"
#include "stir/Array.h"
#include "stir/DiscretisedDensity.h"
#include "stir/shared_ptr.h"
Expand Down Expand Up @@ -77,15 +78,16 @@ START_NAMESPACE_STIR
*/
template <typename elemT>
class QuadraticPrior: public
RegisteredParsingObject< QuadraticPrior<elemT>,
GeneralisedPrior<DiscretisedDensity<3,elemT> >,
PriorWithParabolicSurrogate<DiscretisedDensity<3,elemT> >
>
RegisteredParsingObject<
QuadraticPrior<elemT>,
GeneralisedPrior<DiscretisedDensity<3, elemT> >,
GeneralisedConvexPrior<DiscretisedDensity<3, elemT> > >,
PriorWithParabolicSurrogate<DiscretisedDensity<3, elemT> >
{
private:
typedef
RegisteredParsingObject< QuadraticPrior<elemT>,
GeneralisedPrior<DiscretisedDensity<3,elemT> >,
GeneralisedConvexPrior<DiscretisedDensity<3,elemT> >,
PriorWithParabolicSurrogate<DiscretisedDensity<3,elemT> > >
base_type;

Expand All @@ -103,6 +105,8 @@ class QuadraticPrior: public
parabolic_surrogate_curvature_depends_on_argument() const
{ return false; }

bool is_convex() const;

//! compute the value of the function
double
compute_value(const DiscretisedDensity<3,elemT> &current_image_estimate);
Expand All @@ -116,20 +120,20 @@ class QuadraticPrior: public
void parabolic_surrogate_curvature(DiscretisedDensity<3,elemT>& parabolic_surrogate_curvature,
const DiscretisedDensity<3,elemT> &current_image_estimate);

//! compute Hessian
void compute_Hessian(DiscretisedDensity<3,elemT>& prior_Hessian_for_single_densel,
const BasicCoordinate<3,int>& coords,
const DiscretisedDensity<3,elemT> &current_image_estimate);
virtual void
compute_Hessian(DiscretisedDensity<3,elemT>& prior_Hessian_for_single_densel,
const BasicCoordinate<3,int>& coords,
const DiscretisedDensity<3,elemT> &current_image_estimate) const;

//! Call accumulate_Hessian_times_input
virtual Succeeded
virtual void
add_multiplication_with_approximate_Hessian(DiscretisedDensity<3,elemT>& output,
const DiscretisedDensity<3,elemT>& input) const;

//! Compute the multiplication of the hessian of the prior multiplied by the input.
//! For the quadratic function, the hessian of the prior is 1.
//! Therefore this will return the weights multiplied by the input.
virtual Succeeded accumulate_Hessian_times_input(DiscretisedDensity<3,elemT>& output,
virtual void accumulate_Hessian_times_input(DiscretisedDensity<3,elemT>& output,
const DiscretisedDensity<3,elemT>& current_estimate,
const DiscretisedDensity<3,elemT>& input) const;

Expand Down Expand Up @@ -179,6 +183,19 @@ class QuadraticPrior: public
virtual bool post_processing();
private:
shared_ptr<const DiscretisedDensity<3,elemT> > kappa_ptr;

//! The second partial derivatives of the Quadratic Prior
/*!
derivative_20 refers to the second derivative w.r.t. x_j (i.e. diagonal elements of the Hessian)
derivative_11 refers to the second derivative w.r.t. x_j and x_k (i.e. off-diagonal elements of the Hessian)
* @param x_j is the target voxel.
* @param x_k is the voxel in the neighbourhood.
* @return the second order partial derivatives of the Quadratic Prior
*/
//@{
elemT derivative_20(const elemT x_j, const elemT x_k) const;
elemT derivative_11(const elemT x_j, const elemT x_k) const;
//@}
};


Expand Down
28 changes: 27 additions & 1 deletion src/include/stir/recon_buildblock/RelativeDifferencePrior.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,19 @@ class RelativeDifferencePrior: public
void compute_gradient(DiscretisedDensity<3,elemT>& prior_gradient,
const DiscretisedDensity<3,elemT> &current_image_estimate);

virtual void compute_Hessian(DiscretisedDensity<3,elemT>& prior_Hessian_for_single_densel,
const BasicCoordinate<3,int>& coords,
const DiscretisedDensity<3,elemT> &current_image_estimate) const;

virtual Succeeded
virtual void
add_multiplication_with_approximate_Hessian(DiscretisedDensity<3,elemT>& output,
const DiscretisedDensity<3,elemT>& input) const;

//! Compute the multiplication of the hessian of the prior multiplied by the input.
virtual void accumulate_Hessian_times_input(DiscretisedDensity<3,elemT>& output,
const DiscretisedDensity<3,elemT>& current_estimate,
const DiscretisedDensity<3,elemT>& input) const;

//! get the gamma value used in RDP
float get_gamma() const;
//! set the gamma value used in the RDP
Expand Down Expand Up @@ -160,6 +168,8 @@ class RelativeDifferencePrior: public
//! Has to be called before using this object
virtual Succeeded set_up(shared_ptr<DiscretisedDensity<3,elemT> > const& target_sptr);

bool is_convex() const;

protected:
//! Create variable gamma for Relative Difference Penalty
float gamma;
Expand Down Expand Up @@ -193,6 +203,22 @@ class RelativeDifferencePrior: public
virtual bool post_processing();
private:
shared_ptr<DiscretisedDensity<3,elemT> > kappa_ptr;

//! The second partial derivatives of the Relative Difference Prior
/*!
derivative_20 refers to the second derivative w.r.t. x_j only (i.e. diagonal elements of the Hessian)
derivative_11 refers to the second derivative w.r.t. x_j and x_k (i.e. off-diagonal elements of the Hessian)
See J. Nuyts, et al., 2002, Equation 7.
In the instance x_j, x_k and epsilon equal 0.0, these functions return 0.0 to prevent returning an undefined value
due to 0/0 computation. This is a "reasonable" solution to this issue.
* @param x_j is the target voxel.
* @param x_k is the voxel in the neighbourhood.
* @return the second order partial derivatives of the Relative Difference Prior
*/
//@{
elemT derivative_20(const elemT x_j, const elemT x_k) const;
elemT derivative_11(const elemT x_j, const elemT x_k) const;
//@}
};


Expand Down
7 changes: 3 additions & 4 deletions src/iterative/OSSPS/OSSPSReconstruction.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -327,10 +327,10 @@ update_estimate(TargetT &current_image_estimate)
// For the quadratic prior, this is independent of the image (only on kappa's)
// And of course, it's also independent when there is no prior
// TODO by default, this should be off probably (to save time).
auto* parabolic_surrogate_prior = dynamic_cast<PriorWithParabolicSurrogate<TargetT>*>(this->get_prior_ptr());
const bool recompute_penalty_term_in_denominator =
!this->objective_function_sptr->prior_is_zero() &&
static_cast<PriorWithParabolicSurrogate<TargetT> const&>(*this->get_prior_ptr()).
parabolic_surrogate_curvature_depends_on_argument();
parabolic_surrogate_prior->parabolic_surrogate_curvature_depends_on_argument();
#ifndef PARALLEL
//CPUTimer subset_timer;
//subset_timer.start();
Expand Down Expand Up @@ -366,8 +366,7 @@ update_estimate(TargetT &current_image_estimate)
// avoid work (or crash) when penalty is 0
if (!this->objective_function_sptr->prior_is_zero())
{
static_cast<PriorWithParabolicSurrogate<TargetT>&>(*get_prior_ptr()).
parabolic_surrogate_curvature(*work_image_ptr, current_image_estimate);
parabolic_surrogate_prior->parabolic_surrogate_curvature(*work_image_ptr, current_image_estimate);
//*work_image_ptr *= 2;
//*work_image_ptr += *precomputed_denominator_ptr ;
std::transform(work_image_ptr->begin_all(), work_image_ptr->end_all(),
Expand Down
1 change: 1 addition & 0 deletions src/recon_buildblock/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ set(${dir_LIB_SOURCES}
BinNormalisationFromAttenuationImage.cxx
BinNormalisationSPECT.cxx
GeneralisedPrior.cxx
GeneralisedConvexPrior.cxx
ProjDataRebinning.cxx
FourierRebinning.cxx
PLSPrior.cxx
Expand Down
6 changes: 6 additions & 0 deletions src/recon_buildblock/FilterRootPrior.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ FilterRootPrior(shared_ptr<DataProcessor<DataT> >const& filter_sptr, float penal
this->penalisation_factor = penalisation_factor_v;
}

template <typename DataT>
bool FilterRootPrior<DataT>::
is_convex() const
{
return false;
}

template < class T>
static inline int
Expand Down
Loading