Source code for sporco.pgm.momentum

# -*- coding: utf-8 -*-
# Copyright (C) 2016-2020 by Cristina Garcia-Cardona <cgarciac@lanl.gov>
#                            Brendt Wohlberg <brendt@ieee.org>
# All rights reserved. BSD 3-clause License.
# This file is part of the SPORCO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

"""Momentum coefficient options for PGM algorithms"""

from __future__ import division, print_function

import numpy as np


__author__ = """Cristina Garcia-Cardona <cgarciac@lanl.gov>"""



class MomentumBase(object):
    """Base class for computing momentum coefficient for
    accelerated proximal gradient method.

    This class is intended to be a base class of other classes
    that specialise to specific momentum coefficient options.

    After termination of the :meth:`update` method the new
    momentum coefficient is returned.
    """

    def __init__(self):
        super(MomentumBase, self).__init__()

    def update(self, var=None):
        """Update momentum coefficient.

        Overriding this method is required.

        Parameters
        ----------
        var : float or int, optional
          Quantity from which the new coefficient is computed. It is
          not used here; it is accepted so that the base signature
          agrees with the single-argument ``update`` methods of the
          derived classes, and a call through the common interface
          fails with :class:`NotImplementedError` rather than
          :class:`TypeError`.

        Raises
        ------
        NotImplementedError
          Always, since this method must be overridden.
        """

        raise NotImplementedError()
class MomentumNesterov(MomentumBase):
    r"""Nesterov's momentum coefficient :cite:`beck-2009-fast`

    The coefficient at iteration :math:`k + 1` is obtained from the
    coefficient at iteration :math:`k` via

    .. math::
       t^{(k+1)} = \frac{1}{2} \left( 1 + \sqrt{1 + 4 \; (t^{(k)})^2}
       \right) \;.
    """

    def __init__(self):
        super(MomentumNesterov, self).__init__()

    def update(self, t):
        """Update momentum coefficient.

        Parameters
        ----------
        t : float
          Current momentum coefficient.

        Returns
        -------
        float
          New momentum coefficient, converted to a Python float.
        """

        # Argument of the square root in the update formula
        radicand = 1. + 4. * t**2
        return 0.5 * float(1. + np.sqrt(radicand))
class MomentumLinear(MomentumBase):
    r"""Linear momentum coefficient :cite:`chambolle-2015-convergence`

    The coefficient is a linear function of the iteration index
    :math:`k`,

    .. math::
       t^{(k+1)} = \frac{k + b}{b} \;,

    where :math:`b` is a positive constant such that
    :math:`b \geq 2`.
    """

    def __init__(self, b=2.):
        """
        Parameters
        ----------
        b : float
          Summand in numerator and factor in denominator of update.
        """

        super(MomentumLinear, self).__init__()
        self.b = b

    def update(self, k):
        """Update momentum coefficient.

        Parameters
        ----------
        k : int or float
          Iteration index.

        Returns
        -------
        float
          New momentum coefficient ``(k + b) / b``.
        """

        numerator = k + self.b
        return numerator / self.b
class MomentumGenLinear(MomentumBase):
    r"""Generalized linear momentum coefficient
    :cite:`rodriguez-2019-convergence`

    The coefficient is a linear function of the iteration index
    :math:`k`,

    .. math::
       t^{(k+1)} = \frac{k + a}{b} \;,

    where :math:`a, b` are positive constants such that
    :math:`a \geq b - 1` and :math:`b \geq 2`.
    """

    def __init__(self, a=50., b=2.):
        """
        Parameters
        ----------
        a : float
          Summand in numerator of update.
        b : float
          Factor in denominator of update.
        """

        super(MomentumGenLinear, self).__init__()
        self.a = a
        self.b = b

    def update(self, k):
        """Update momentum coefficient.

        Parameters
        ----------
        k : int or float
          Iteration index.

        Returns
        -------
        float
          New momentum coefficient ``(k + a) / b``.
        """

        numerator = k + self.a
        return numerator / self.b