Source code for sporco.fista.ppp

# -*- coding: utf-8 -*-
# Copyright (C) 2019-2020 by Brendt Wohlberg <brendt@ieee.org>
#                            Ulugbek Kamilov <kamilov@wustl.edu>
# All rights reserved. BSD 3-clause License.
# This file is part of the SPORCO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

"""Classes for FISTA variant of the Plug and Play Priors (PPP) algorithm."""

from __future__ import division, absolute_import, print_function

import numpy as np

from sporco.fista import fista


__author__ = """\n""".join(['Brendt Wohlberg <brendt@ieee.org>',
                            'Ulugbek Kamilov <kamilov@wustl.edu>'])



class GenericPPP(fista.FISTA):
    """Base class for Plug and Play Priors (PPP) FISTA solvers
    :cite:`kamilov-2017-plugandplay`.

    Derived classes are required to override :meth:`gradf`,
    :meth:`proxg`, and :meth:`f`; the default implementations raise
    :exc:`NotImplementedError`.
    """

    def __init__(self, xshape, opt=None):
        """
        Parameters
        ----------
        xshape : tuple of ints
          Shape of working variable X
        opt : :class:`GenericPPP.Options` object
          Algorithm options
        """

        if opt is None:
            opt = GenericPPP.Options()

        # Set dtype attribute, default is np.float32
        self.set_dtype(opt, np.dtype(np.float32))

        super(GenericPPP, self).__init__(xshape, self.dtype, opt)

        self.store_prev()
        self.Y = self.X.copy()
        # Allocate with the same dtype as Y: a bare np.zeros(shape) is
        # always float64, which would upcast X - Yprv in rsdl when the
        # working dtype is float32.
        self.Yprv = np.zeros(self.Y.shape, dtype=self.Y.dtype)

    # Only the data fidelity term f can be evaluated; the regulariser g
    # is implicit (applied via proxg), so it has no objective value.
    itstat_fields_objfn = ('FVal',)
    hdrtxt_objfn = ('FVal',)
    # NOTE(review): spelled "objfun" (not "objfn") -- presumably this
    # matches the attribute name read by the base class; confirm before
    # renaming.
    hdrval_objfun = {'FVal': 'FVal'}

    def eval_grad(self):
        """Compute the gradient of :math:`f`.

        The gradient is evaluated at the auxiliary variable ``self.Y``
        by delegating to :meth:`gradf`.
        """

        return self.gradf(self.Y)

    def eval_proxop(self, V):
        """Compute proximal operator of :math:`g`.

        Delegates to :meth:`proxg`, passing the current step parameter
        ``self.L``.
        """

        return self.proxg(V, self.L)

    def rsdl(self):
        """Compute fixed point residual.

        The residual is the Euclidean norm of ``self.X - self.Yprv``
        taken over all elements.
        """

        return np.linalg.norm((self.X - self.Yprv).ravel())

    def eval_objfn(self):
        r"""Compute components of objective function.

        In this case the regularisation term is implicit so we can only
        evaluate the data fidelity term represented by the
        :math:`f(\cdot)` component of the functional to be minimised.

        Returns
        -------
        tuple
          One-element tuple containing :math:`f` evaluated at ``self.X``
        """

        return (self.f(self.X),)

    def gradf(self, X):
        r"""Compute the gradient of :math:`f(\cdot)`.

        Overriding this method is required.

        Parameters
        ----------
        X : array_like
          Point at which the gradient is evaluated (``self.Y`` when
          called from :meth:`eval_grad`)
        """

        raise NotImplementedError()

    def proxg(self, X, L):
        r"""Compute the proximal operator of :math:`L^{-1} g(\cdot)`.

        Overriding this method is required. Note that this method
        should compute the proximal operator of
        :math:`L^{-1} g(\cdot)`, *not* the proximal operator of
        :math:`L g(\cdot)`.

        Parameters
        ----------
        X : array_like
          Point at which the proximal operator is evaluated
        L : float
          Step parameter (``self.L`` when called from
          :meth:`eval_proxop`)
        """

        raise NotImplementedError()

    def f(self, X):
        r"""Evaluate the data fidelity term :math:`f(\mathbf{x})`.

        Overriding this method is required.

        Parameters
        ----------
        X : array_like
          Point at which the data fidelity term is evaluated
        """

        raise NotImplementedError()
class PPP(GenericPPP):
    """Plug and Play Priors (PPP) solver
    :cite:`kamilov-2017-plugandplay` that can be used without the need
    to derive a new class.

    Instead of overriding :meth:`f`, :meth:`gradf`, and :meth:`proxg`,
    the caller supplies them as plain callables at construction time.
    """

    def __init__(self, xshape, f, gradf, proxg, opt=None):
        """
        Parameters
        ----------
        xshape : tuple of ints
          Shape of working variable X
        f : function
          Function evaluating the data fidelity term, called as ``f(X)``
        gradf : function
          Function computing the gradient of the data fidelity term,
          called as ``gradf(X)``
        proxg : function
          Function computing the proximal operator of the
          regularisation term, called as ``proxg(X, L)``; it must
          compute the proximal operator of :math:`L^{-1} g(\\cdot)`
        opt : :class:`PPP.Options` object
          Algorithm options
        """

        options = PPP.Options() if opt is None else opt
        super(PPP, self).__init__(xshape, options)

        # Instance attributes shadow the GenericPPP methods of the same
        # names, so the solver calls the user-supplied functions.
        self.f, self.gradf, self.proxg = f, gradf, proxg