# -*- coding: utf-8 -*-
# Copyright (C) 2016-2020 by Cristina Garcia-Cardona <cgarciac@lanl.gov>
# Brendt Wohlberg <brendt@ieee.org>
# All rights reserved. BSD 3-clause License.
# This file is part of the SPORCO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.
"""Base classes for PGM algorithms."""
from __future__ import division, print_function
from builtins import range
import copy
import numpy as np
from sporco.cdict import ConstrainedDict
from sporco.common import IterativeSolver, solve_status_str
from sporco.fft import rfftn, irfftn
from sporco.array import transpose_ntpl_list
from sporco.util import Timer
from .backtrack import BacktrackRobust
from .momentum import MomentumNesterov
from .stepsize import StepSizePolicyBB
# Module authorship and the public API exported by ``from ... import *``
__author__ = "Cristina Garcia-Cardona <cgarciac@lanl.gov>"
__all__ = ['PGM', 'PGMDFT']
class PGM(IterativeSolver):
    r"""Base class for Proximal Gradient Method (PGM) algorithms
    (see for example Ch. 10 of :cite:`beck-2017-first` and
    Sec. 4.2 and 4.3 of :cite:`parikh-2014-proximal`). Algorithms
    such as FISTA :cite:`beck-2009-fast` and a robust variant of
    FISTA :cite:`florea-2017-robust` are also supported.

    Solve optimisation problems of the form

    .. math::
       \mathrm{argmin}_{\mathbf{x}} \; f(\mathbf{x}) + g(\mathbf{x}) \;\;,

    where :math:`f, g` are convex functions and :math:`f` is smooth.

    This class is intended to be a base class of other classes that
    specialise to specific optimisation problems.

    After termination of the :meth:`solve` method, attribute
    :attr:`itstat` is a list of tuples representing statistics of each
    iteration. The default fields of the named tuple
    ``IterationStats`` are:

       ``Iter`` : Iteration number

       ``ObjFun`` : Objective function value

       ``FVal`` : Value of smooth objective function component :math:`f`

       ``GVal`` : Value of objective function component :math:`g`

       ``Rsdl`` : Residual

       ``F_Btrack`` : Value of objective function :math:`f + g`
       (see Sec. 2.2 of :cite:`beck-2009-fast`) when backtracking

       ``Q_Btrack`` : Value of Quadratic approximation :math:`Q_L`
       (see Sec. 2.3 of :cite:`beck-2009-fast`) when backtracking

       ``IterBTrack`` : Number of iterations in backtracking

       ``L`` : Inverse of gradient step parameter

       ``Time`` : Cumulative run time
    """

    class Options(ConstrainedDict):
        r"""PGM algorithm options.

        Options:

          ``FastSolve`` : Flag determining whether non-essential
          computation is skipped. When ``FastSolve`` is ``True``, the
          functional value and related iteration statistics are not
          computed (and consequently no iteration status is displayed,
          even if ``Verbose`` is ``True``). Residuals are also not
          calculated, in which case the residual-based stopping method
          is also disabled, with the number of iterations determined
          only by ``MaxMainIter``.

          ``Verbose`` : Flag determining whether iteration status is
          displayed.

          ``StatusHeader`` : Flag determining whether status header and
          separator are displayed.

          ``DataType`` : Specify data type for solution variables,
          e.g. ``np.float32``.

          ``X0`` : Initial value for X variable.

          ``Callback`` : Callback function to be called at the end of
          every iteration. If it returns a true value, the iterations
          are terminated.

          ``MaxMainIter`` : Maximum main iterations.

          ``IterTimer`` : Label of the timer to use for iteration times.

          ``RelStopTol`` : Relative convergence tolerance for fixed point
          residual (see Sec. 4.3 of :cite:`liu-2018-first`).

          ``L`` : Inverse of gradient step parameter :math:`L`.

          ``AutoStop`` : Options for adaptive stopping strategy (fixed
          point residual, see Sec. 4.3 of :cite:`liu-2018-first`).

            ``Enabled`` : Flag determining whether the adaptive stopping
            relative parameter strategy is enabled.

            ``Tau0`` : numerator in adaptive criterion
            (:math:`\tau_0` in :cite:`liu-2018-first`).

          ``Monotone`` : Flag determining whether a monotone PGM version
          from :cite:`beck-2009-tv` is used. Default is False.

          ``Momentum`` : Momentum coefficient adaptation object. Standard
          options are Nesterov :cite:`beck-2009-fast`
          (:class:`.MomentumNesterov`), Linear
          :cite:`chambolle-2015-convergence`
          (:class:`.MomentumLinear`), and GenLinear
          :cite:`rodriguez-2019-convergence`
          (:class:`.MomentumGenLinear`), but a custom class derived
          from :class:`.MomentumBase` may also be specified. Default
          is :class:`.MomentumNesterov`.

          ``StepSizePolicy`` : non-iterative L adaptation object.
          Standard options are Cauchy :cite:`yuan-2008-stepsize`
          Sec. 3 (:class:`.StepSizePolicyCauchy`), and Barzilai-Borwein
          :cite:`barzilai-1988-stepsize`
          (:class:`.StepSizePolicyBB`), but a custom class derived
          from :class:`.StepSizePolicyBase` may also be specified.
          Default is None, no non-iterative L adaptation. Note that if
          both step size and Backtrack strategies are enabled, only
          Backtrack is used.

          ``Backtrack`` : PGM backtracking options. Options are Standard
          :cite:`beck-2009-fast` (:class:`.BacktrackStandard`) and
          Robust :cite:`florea-2017-robust` (:class:`.BacktrackRobust`),
          but a custom class derived from :class:`.BacktrackBase` may
          also be specified. Default is None, no backtracking. Note that
          if both step size and Backtrack strategies are enabled, only
          Backtrack is used.
        """

        defaults = {'FastSolve': False, 'Verbose': False,
                    'StatusHeader': True, 'DataType': None,
                    'X0': None, 'Callback': None,
                    'MaxMainIter': 1000, 'IterTimer': 'solve',
                    'RelStopTol': 1e-3, 'L': None,
                    'AutoStop': {'Enabled': False, 'Tau0': 1e-2},
                    'Monotone': False,
                    'Momentum': MomentumNesterov(),
                    'StepSizePolicy': None,
                    'Backtrack': None}

        def __init__(self, opt=None):
            """
            Parameters
            ----------
            opt : dict or None, optional (default None)
              PGM algorithm options
            """

            if opt is None:
                opt = {}
            ConstrainedDict.__init__(self, opt)

    fwiter = 4
    """Field width for iteration count display column"""
    fpothr = 2
    """Field precision for other display columns"""

    itstat_fields_objfn = ('ObjFun', 'FVal', 'GVal')
    """Fields in IterationStats associated with the objective function;
    see :meth:`eval_objfn`"""
    itstat_fields_alg = ('Rsdl', 'F_Btrack', 'Q_Btrack', 'IterBTrack', 'L')
    """Fields in IterationStats associated with the specific solver
    algorithm"""
    itstat_fields_extra = ()
    """Non-standard fields in IterationStats; see :meth:`itstat_extra`"""

    hdrtxt_objfn = ('Fnc', 'f', 'g')
    """Display column headers associated with the objective function;
    see :meth:`eval_objfn`"""
    hdrval_objfun = {'Fnc': 'ObjFun', 'f': 'FVal', 'g': 'GVal'}
    """Dictionary mapping display column headers in :attr:`hdrtxt_objfn`
    to IterationStats entries"""

    def __new__(cls, *args, **kwargs):
        """Create a PGM object and start its initialisation timer."""

        instance = super(PGM, cls).__new__(cls)
        instance.timer = Timer(['init', 'solve', 'solve_wo_func',
                                'solve_wo_rsdl', 'solve_wo_btrack'])
        instance.timer.start('init')
        return instance

    def __init__(self, xshape, dtype, opt=None):
        r"""
        Parameters
        ----------
        xshape : tuple of ints
          Shape of working variable X
        dtype : data-type
          Data type for working variables (overridden by 'DataType' option)
        opt : :class:`PGM.Options` object
          Algorithm options

        Raises
        ------
        TypeError
          If ``opt`` is not an instance of :class:`PGM.Options`
        """

        if opt is None:
            opt = PGM.Options()
        if not isinstance(opt, PGM.Options):
            raise TypeError("Parameter opt must be an instance of "
                            "PGM.Options")

        self.opt = opt

        # DataType option overrides data type inferred from __init__
        # parameters of derived class
        self.set_dtype(opt, dtype)

        # Initialise attributes representing step parameter and other
        # parameters
        self.set_attr('L', opt['L'], dval=1.0, dtype=self.dtype)

        # Configure policy for step size.
        # Step size policy is turned off if Backtrack is enabled
        self.stepsizepolicy = self.opt['StepSizePolicy']
        if self.opt['Backtrack'] is not None:
            self.stepsizepolicy = None

        # Configure Momentum coefficients
        self.momentum = self.opt['Momentum']

        # If using adaptive stop criterion, set tau0 parameter
        if self.opt['AutoStop', 'Enabled']:
            self.tau0 = self.opt['AutoStop', 'Tau0']

        # Initialise working variable X
        if self.opt['X0'] is None:
            self.X = self.xinit(xshape)
        else:
            self.X = self.opt['X0'].astype(self.dtype, copy=True)

        # Default values for variables created only if Backtrack is enabled
        self.F = None
        self.Q = None
        self.iterBTrack = None
        self.backtrack = self.opt['Backtrack']

        self.Y = None
        self.itstat = []
        self.k = 0
        self.t = 1

    def xinit(self, xshape):
        """Return initialiser for working variable X."""

        return np.zeros(xshape, dtype=self.dtype)

    def solve(self):
        """Start (or re-start) optimisation. This method implements the
        framework for the iterations of a PGM algorithm. There is
        sufficient flexibility in overriding the component methods that
        it calls that it is usually not necessary to override this method
        in derived classes.

        If option ``Verbose`` is ``True``, the progress of the
        optimisation is displayed at every iteration. At termination
        of this method, attribute :attr:`itstat` is a list of tuples
        representing statistics of each iteration; no entries are
        appended when option ``FastSolve`` is ``True``.

        Attribute :attr:`timer` is an instance of :class:`.util.Timer`
        that provides the following labelled timers:

          ``init``: Time taken for object initialisation by
          :meth:`__init__`

          ``solve``: Total time taken by call(s) to :meth:`solve`

          ``solve_wo_func``: Total time taken by call(s) to
          :meth:`solve`, excluding time taken to compute functional
          value and related iteration statistics

          ``solve_wo_rsdl`` : Total time taken by call(s) to
          :meth:`solve`, excluding time taken to compute functional
          value and related iteration statistics as well as time taken
          to compute residuals

          ``solve_wo_btrack`` : Total time taken by call(s) to
          :meth:`solve`, excluding time taken to compute functional
          value and related iteration statistics as well as time taken
          to compute residuals and implemented ``Backtrack`` mechanism
        """

        # Open status display
        fmtstr, nsep = self.display_start()

        # Start solve timer
        self.timer.start(['solve', 'solve_wo_func', 'solve_wo_rsdl',
                          'solve_wo_btrack'])

        # Main optimisation iterations
        for self.k in range(self.k, self.k + self.opt['MaxMainIter']):

            # Update record of X and Y from previous iteration
            self.on_iteration_start()

            if self.opt['Backtrack'] is not None and self.k >= 0:
                # Backtracking performs the proximal step(s) and the
                # auxiliary update itself
                self.timer.stop('solve_wo_btrack')
                self.backtrack.update(self)
                self.timer.start('solve_wo_btrack')
            else:
                # Compute just proximal step
                self.xstep()
                # Update by combining previous iterates
                self.ystep()

            # Compute residuals and stopping thresholds
            self.timer.stop(['solve_wo_rsdl', 'solve_wo_btrack'])
            if not self.opt['FastSolve']:
                frcxd, adapt_tol = self.compute_residuals()
            self.timer.start('solve_wo_rsdl')

            # Compute and record other iteration statistics and
            # display iteration stats if Verbose option enabled
            self.timer.stop(['solve_wo_func', 'solve_wo_rsdl',
                             'solve_wo_btrack'])
            if not self.opt['FastSolve']:
                itst = self.iteration_stats(self.k, frcxd)
                self.itstat.append(itst)
                self.display_status(fmtstr, itst)
            self.timer.start(['solve_wo_func', 'solve_wo_rsdl',
                              'solve_wo_btrack'])

            # Call callback function if defined
            if self.opt['Callback'] is not None:
                if self.opt['Callback'](self):
                    break

            # Stop if residual-based stopping tolerances reached
            if not self.opt['FastSolve']:
                if frcxd < adapt_tol:
                    break

        # Increment iteration count (so that a restart continues from
        # the next iteration)
        self.k += 1

        # Record solve time
        self.timer.stop(['solve', 'solve_wo_func', 'solve_wo_rsdl',
                         'solve_wo_btrack'])

        # Print final separator string if Verbose option enabled
        self.display_end(nsep)

        return self.getmin()

    def getmin(self):
        """Get minimiser after optimisation."""

        return self.X

    def xstep(self, grad=None):
        """Compute proximal update (gradient descent + regularization).

        Optionally, a monotone PGM version from :cite:`beck-2009-tv`
        is available: if the new iterate increases the objective
        function, X is reverted to its value at the start of the
        iteration.

        Parameters
        ----------
        grad : array_like or None, optional (default None)
          Gradient of :math:`f` at Y; computed via :meth:`grad_f` if
          not provided

        Returns
        -------
        grad : array_like
          The gradient used for the update
        """

        if grad is None:
            grad = self.grad_f()

        if self.stepsizepolicy is not None:
            if self.k > 1:
                self.L = self.stepsizepolicy.update(self, grad)
            if isinstance(self.stepsizepolicy, StepSizePolicyBB):
                # BB variants are two-point methods
                self.stepsizepolicy.store_prev_state(self.X, grad)

        V = self.Y - (1. / self.L) * grad
        self.X = self.prox_g(V)

        if self.opt['Monotone'] and self.k > 0:
            # Keep the unconditional prox result for the Y update
            self.ZZ = self.X.copy()
            self.objfn = self.eval_objfn()
            if self.objfn_prev[0] < self.objfn[0]:
                # If increment on objective function
                # revert to previous iterate
                self.X = self.Xprv.copy()
                self.objfn = self.objfn_prev

        return grad

    def ystep(self):
        """Build next update by a smart combination of previous updates
        (standard PGM :cite:`beck-2009-fast`). Optionally, a monotone
        PGM version from :cite:`beck-2009-tv` is available.
        """

        # Update t step
        tprv = self.t
        self.t = self.momentum.update(self.var_momentum())

        # Update Y
        if self.opt['Monotone'] and self.k > 0:
            self.Y = self.X + (tprv / self.t) * (self.ZZ - self.X) \
                + ((tprv - 1.) / self.t) * (self.X - self.Xprv)
        else:
            self.Y = self.X + ((tprv - 1.) / self.t) * (self.X - self.Xprv)

    def eval_linear_approx(self, Dxy, gradY):
        r"""Compute term :math:`\langle \nabla f(\mathbf{y}), \mathbf{x}
        - \mathbf{y} \rangle` that is part of the quadratic function
        :math:`Q_L` used for backtracking.
        """

        return np.sum(Dxy * gradY)

    def grad_f(self, V=None):
        """Compute gradient of :math:`f` at V.

        ``V`` is optional since :meth:`xstep` calls this method without
        arguments.

        Overriding this method is required.
        """

        raise NotImplementedError()

    def prox_g(self, V):
        """Compute proximal operator of :math:`g`.

        Overriding this method is required.
        """

        raise NotImplementedError()

    def hessian_f(self, V):
        """Compute Hessian of :math:`f` and apply to V.

        Overriding this method is required.
        """

        raise NotImplementedError()

    def on_iteration_start(self):
        """Store previous X and Y states."""

        self.Xprv = self.X.copy()
        # Y of the previous iteration is needed by the residual and by
        # robust backtracking
        if (not self.opt['FastSolve'] or
                isinstance(self.backtrack, BacktrackRobust)):
            self.Yprv = self.Y.copy()

        if self.opt['Monotone']:
            if self.k == 0:
                self.objfn = self.eval_objfn()
            self.objfn_prev = self.objfn

    def eval_Dxy(self):
        """Evaluate difference of state and auxiliary state updates."""

        return self.X - self.Y

    def compute_residuals(self):
        """Compute residuals and stopping thresholds.

        Returns
        -------
        r : float
          Fixed point residual
        adapt_tol : float
          Stopping tolerance: ``RelStopTol``, or
          :math:`\\tau_0 / (1 + k)` when ``AutoStop`` is enabled
        """

        r = self.rsdl()
        adapt_tol = self.opt['RelStopTol']

        if self.opt['AutoStop', 'Enabled']:
            adapt_tol = self.tau0 / (1. + self.k)

        return r, adapt_tol

    @classmethod
    def hdrtxt(cls):
        """Construct tuple of status display column title."""

        return ('Itn',) + cls.hdrtxt_objfn + ('Rsdl', 'F', 'Q', 'It_Bt', 'L')

    @classmethod
    def hdrval(cls):
        """Construct dictionary mapping display column title to
        IterationStats entries.
        """

        hdr = {'Itn': 'Iter'}
        hdr.update(cls.hdrval_objfun)
        hdr.update({'Rsdl': 'Rsdl', 'F': 'F_Btrack', 'Q': 'Q_Btrack',
                    'It_Bt': 'IterBTrack', 'L': 'L'})

        return hdr

    def iteration_stats(self, k, frcxd):
        """Construct iteration stats record tuple.

        In the monotone variant the objective function values already
        computed (and possibly reverted) in :meth:`xstep` are reused
        rather than recomputed.
        """

        tk = self.timer.elapsed(self.opt['IterTimer'])
        if self.opt['Monotone']:
            objfn = self.objfn
        else:
            objfn = self.eval_objfn()
        tpl = ((k,) + objfn +
               (frcxd, self.F, self.Q, self.iterBTrack, self.L) +
               self.itstat_extra() + (tk,))
        return type(self).IterationStats(*tpl)

    def eval_objfn(self):
        """Compute components of objective function as well as total
        contribution to objective function.

        Returns
        -------
        tuple
          ``(obj, fval, gval)`` where ``obj = fval + gval``
        """

        fval = self.obfn_f(self.X)
        gval = self.obfn_g(self.X)
        obj = fval + gval
        return (obj, fval, gval)

    def getitstat(self):
        """Get iteration stats as named tuple of arrays instead of
        array of named tuples.
        """

        return transpose_ntpl_list(self.itstat)

    def display_start(self):
        """Set up status display if option selected. NB: this method
        assumes that the first entry is the iteration count and the
        last is the L value.
        """

        if self.opt['Verbose']:
            # If backtracking option enabled F, Q, itBT, L are
            # included in iteration status
            if self.opt['Backtrack'] is not None:
                hdrtxt = type(self).hdrtxt()
            else:
                hdrtxt = type(self).hdrtxt()[0:-4]
            # Call utility function to construct status display formatting
            hdrstr, fmtstr, nsep = solve_status_str(
                hdrtxt, fmtmap={'It_Bt': '%5d'}, fwdth0=type(self).fwiter,
                fprec=type(self).fpothr)
            # Print header and separator strings
            if self.opt['StatusHeader']:
                print(hdrstr)
                print("-" * nsep)
        else:
            fmtstr, nsep = '', 0

        return fmtstr, nsep

    def display_status(self, fmtstr, itst):
        """Display current iteration status as selection of fields from
        iteration stats tuple.
        """

        if self.opt['Verbose']:
            hdrtxt = type(self).hdrtxt()
            hdrval = type(self).hdrval()
            itdsp = tuple([getattr(itst, hdrval[col]) for col in hdrtxt])
            # Backtracking columns are only shown when backtracking
            if self.opt['Backtrack'] is None:
                itdsp = itdsp[0:-4]

            print(fmtstr % itdsp)

    def display_end(self, nsep):
        """Terminate status display if option selected."""

        if self.opt['Verbose'] and self.opt['StatusHeader']:
            print("-" * nsep)

    def var_x(self):
        r"""Get :math:`\mathbf{x}` variable."""

        return self.X

    def var_y(self, y=None):
        r"""Get, or update and get, :math:`\mathbf{y}` variable."""

        if y is not None:
            self.Y = y
        return self.Y

    def var_xprv(self):
        r"""Get :math:`\mathbf{x}` variable of previous iteration."""

        return self.Xprv

    def var_momentum(self):
        """Return the argument for the momentum coefficient update: the
        current ``t`` for Nesterov momentum, otherwise the iteration
        number.
        """

        if isinstance(self.momentum, MomentumNesterov):
            return self.t
        return self.k

    def obfn_f(self, X):
        r"""Compute :math:`f(\mathbf{x})` component of PGM objective
        function.

        Overriding this method is required (even if :meth:`eval_objfn`
        is overridden, since this method is required for backtracking).
        """

        raise NotImplementedError()

    def obfn_g(self, X):
        r"""Compute :math:`g(\mathbf{x})` component of PGM objective
        function.

        Overriding this method is required if :meth:`eval_objfn`
        is not overridden.
        """

        raise NotImplementedError()

    def rsdl(self):
        """Compute fixed point residual (see Sec. 4.3 of
        :cite:`liu-2018-first`)."""

        if self.opt['Monotone'] and self.k > 0:
            return np.linalg.norm((self.X - self.Y).ravel())
        return np.linalg.norm((self.X - self.Yprv).ravel())
class PGMDFT(PGM):
    r"""
    Base class for PGM algorithms with gradients and updates computed
    in the frequency domain.

    |

    .. inheritance-diagram:: PGMDFT
       :parts: 2

    |

    Solve optimisation problems of the form

    .. math::
       \mathrm{argmin}_{\mathbf{x}} \; f(\mathbf{x}) + g(\mathbf{x})
       \;\;,

    where :math:`f, g` are convex functions and :math:`f` is smooth.

    This class specialises class PGM, but remains a base class for
    other classes that specialise to specific optimisation problems.
    """

    class Options(PGM.Options):
        """PGMDFT algorithm options.

        Options include all of those defined in :class:`PGM.Options`.
        """

        defaults = copy.deepcopy(PGM.Options.defaults)

        def __init__(self, opt=None):
            """
            Parameters
            ----------
            opt : dict or None, optional (default None)
              PGMDFT algorithm options
            """

            if opt is None:
                opt = {}
            PGM.Options.__init__(self, opt)

    def __init__(self, xshape, Nv, axisN, dtype, opt=None):
        """
        Parameters
        ----------
        xshape : tuple of ints
          Shape of working variable X (the primary variable)
        Nv : tuple of ints
          Shape of spatial indices of variable X (needed for DFT)
        axisN : tuple of ints
          Axis indices of spatial components of X (needed for DFT)
        dtype : data-type
          Data type for working variables
        opt : :class:`PGMDFT.Options` object
          Algorithm options
        """

        if opt is None:
            opt = PGMDFT.Options()
        super(PGMDFT, self).__init__(xshape, dtype, opt)

        self.Nv = Nv
        self.axisN = axisN

    def xstep(self, gradf=None):
        """Compute proximal update (gradient descent + constraint).

        Variables are mapped back and forth between input and
        frequency domains. Optionally, a monotone PGM version from
        :cite:`beck-2009-tv` is available: if the new iterate increases
        the objective function, both ``Xf`` and ``X`` are reverted to
        the iterate at the start of the iteration.

        Parameters
        ----------
        gradf : array_like or None, optional (default None)
          Frequency domain gradient of :math:`f` at Y; computed via
          :meth:`grad_f` if not provided

        Returns
        -------
        gradf : array_like
          The gradient used for the update
        """

        if gradf is None:
            gradf = self.grad_f()

        if self.stepsizepolicy is not None:
            if self.k > 1:
                self.L = self.stepsizepolicy.update(self, gradf)
            if isinstance(self.stepsizepolicy, StepSizePolicyBB):
                # BB variants are two-point methods
                self.stepsizepolicy.store_prev_state(self.Xf, gradf)

        self.Vf[:] = self.Yf - (1. / self.L) * gradf
        V = irfftn(self.Vf, self.Nv, self.axisN)

        self.X[:] = self.prox_g(V)
        self.Xf = rfftn(self.X, None, self.axisN)

        if self.opt['Monotone'] and self.k > 0:
            # Keep the unconditional prox result for the Y update
            self.ZZf = self.Xf.copy()
            self.objfn = self.eval_objfn()
            if self.objfn_prev[0] < self.objfn[0]:
                # If increment on objective function
                # revert to previous iterate
                self.Xf = self.Xfprv.copy()
                # Keep the spatial domain variable consistent with the
                # reverted frequency domain iterate (X is what
                # getmin returns)
                self.X[:] = irfftn(self.Xf, self.Nv, self.axisN)
                self.objfn = self.objfn_prev

        return gradf

    def ystep(self):
        """Update auxiliary state by a smart combination of previous
        updates in the frequency domain (standard PGM
        :cite:`beck-2009-fast`). Optionally, a monotone PGM version
        from :cite:`beck-2009-tv` is available.
        """

        # Update t step
        tprv = self.t
        self.t = self.momentum.update(self.var_momentum())

        # Update Y
        if self.opt['Monotone'] and self.k > 0:
            self.Yf = self.Xf + (tprv / self.t) * (self.ZZf - self.Xf) \
                + ((tprv - 1.) / self.t) * (self.Xf - self.Xfprv)
        else:
            self.Yf = self.Xf + ((tprv - 1.) / self.t) * (self.Xf - self.Xfprv)

    def on_iteration_start(self):
        """Store previous X and Y in frequency domain."""

        self.Xfprv = self.Xf.copy()
        # Y of the previous iteration is needed by the residual and by
        # robust backtracking
        if (not self.opt['FastSolve'] or
                isinstance(self.backtrack, BacktrackRobust)):
            self.Yfprv = self.Yf.copy()

        if self.opt['Monotone']:
            if self.k == 0:
                self.objfn = self.eval_objfn()
            self.objfn_prev = self.objfn

    def eval_Dxy(self):
        """Evaluate difference of state and auxiliary state in
        frequency domain.
        """

        return self.Xf - self.Yf

    def var_x(self):
        r"""Get :math:`\mathbf{x}` variable in frequency domain."""

        return self.Xf

    def var_y(self, y=None):
        r"""Get, or update and get, :math:`\mathbf{y}` variable in
        frequency domain."""

        if y is not None:
            self.Yf = y
        return self.Yf

    def var_xprv(self):
        r"""Get :math:`\mathbf{x}` variable of previous iteration in
        frequency domain.
        """

        return self.Xfprv

    def eval_linear_approx(self, Dxy, gradY):
        r"""Compute term :math:`\langle \nabla f(\mathbf{y}),
        \mathbf{x} - \mathbf{y} \rangle` (in frequency domain) that is
        part of the quadratic function :math:`Q_L` used for
        backtracking. Since this class computes the backtracking in
        the DFT, it is important to preserve the DFT scaling.
        """

        return np.sum(np.real(np.conj(Dxy) * gradY))