CSC with a Spatial Mask

This example demonstrates the use of cbpdn.AddMaskSim for convolutional sparse coding with a spatial mask [50]. The example problem is inpainting of randomly distributed corruption of a colour image [51].

from __future__ import print_function
from builtins import input

import pyfftw   # See https://github.com/pyFFTW/pyFFTW/issues/40
import numpy as np

from sporco import util
from sporco import signal
from sporco import plot
plot.config_notebook_plotting()
from sporco.admm import tvl2
from sporco.admm import cbpdn
from sporco.metric import psnr

Load a reference image.

img = util.ExampleImages().image('monarch.png', zoom=0.5, scaled=True,
                                 idxexp=np.s_[:, 160:672])

Create a random mask and apply it to the reference image to obtain the test image. (The call to numpy.random.seed ensures that the pseudo-random mask is reproducible.)

np.random.seed(12345)
frc = 0.5
msk = signal.rndmask(img.shape, frc, dtype=np.float32)
imgw = msk * img
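
The masking step can be illustrated without SPORCO. The sketch below assumes that the mask is binary and that frc is the fraction of samples set to zero; random_mask is a hypothetical stand-in written for this illustration, not the signal.rndmask implementation, and the constant placeholder image replaces the real one.

```python
import numpy as np


def random_mask(shape, frc, seed=0):
    """Hypothetical stand-in: each entry is 0 with probability frc, else 1."""
    rng = np.random.default_rng(seed)
    return (rng.uniform(size=shape) >= frc).astype(np.float32)


# Placeholder colour image (constant grey), used only for illustration
img = np.full((64, 64, 3), 0.5, dtype=np.float32)

msk = random_mask(img.shape, frc=0.5)
imgw = msk * img  # masked-out samples become zero

zero_fraction = 1.0 - float(msk.mean())
print("fraction of zeroed samples: %.3f" % zero_fraction)
```

Because the mask has the same shape as the image, samples are removed independently in each colour channel rather than whole pixels at a time.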

Define pad and crop functions.

pn = 8
spad = lambda x: np.pad(x, ((pn, pn), (pn, pn), (0, 0)), mode='symmetric')
zpad = lambda x: np.pad(x, ((pn, pn), (pn, pn), (0, 0)), mode='constant')
crop = lambda x: x[pn:-pn, pn:-pn]

Construct padded mask and test image.

mskp = zpad(msk)
imgwp = spad(imgw)
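
The behaviour of the pad and crop functions can be checked on a small array. The sketch below reuses the same definitions with a smaller pad width (2 rather than 8, chosen only to keep the arrays readable). Zero-padding the mask marks the added border as unobserved, while symmetric padding of the image gives the border plausible content.

```python
import numpy as np

pn = 2  # smaller than the pn = 8 used in the example, for readability
spad = lambda x: np.pad(x, ((pn, pn), (pn, pn), (0, 0)), mode='symmetric')
zpad = lambda x: np.pad(x, ((pn, pn), (pn, pn), (0, 0)), mode='constant')
crop = lambda x: x[pn:-pn, pn:-pn]

# Small single-channel "image" and an all-ones mask of the same shape
x = np.arange(1, 13, dtype=np.float32).reshape(3, 4, 1)
xs = spad(x)                 # border mirrors the image content
xz = zpad(np.ones_like(x))   # border is zero, i.e. treated as masked out

print(xs.shape)
print(xs[..., 0])
print(xz[..., 0])
```

Cropping undoes either padding exactly, which is why the final reconstruction is passed through crop.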

Use \(\ell_2\)-TV denoising with a spatial mask as a non-linear lowpass filter. The highpass component is the difference between the test image and the lowpass component, multiplied by the mask to speed up convergence of the convolutional sparse coding (see [60]).

lmbda = 0.05
opt = tvl2.TVL2Denoise.Options({'Verbose': False, 'MaxMainIter': 200,
                    'DFidWeight': mskp, 'gEvalY': False,
                    'AutoRho': {'Enabled': True}})
b = tvl2.TVL2Denoise(imgwp, lmbda, opt, caxis=2)
sl = b.solve()
sh = mskp * (imgwp - sl)
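
For reference, the decomposition above can be written out as follows. This is a sketch under assumptions: the symbols are introduced here for illustration (W is the padded mask mskp passed as DFidWeight, s is the padded test image imgwp, and G_r and G_c are row and column gradient operators), and the problem is written for a single channel, whereas caxis=2 requests a joint treatment of the colour channels.

```latex
\[
\mathbf{s}_l = \mathrm{argmin}_{\mathbf{x}} \;
\frac{1}{2} \left\| W (\mathbf{x} - \mathbf{s}) \right\|_2^2
+ \lambda \left\| \sqrt{(G_r \mathbf{x})^2 + (G_c \mathbf{x})^2} \right\|_1 ,
\qquad
\mathbf{s}_h = W (\mathbf{s} - \mathbf{s}_l)
\]
```

Here s_l corresponds to sl and s_h to sh in the code above. Because the data fidelity term is weighted by W, masked-out samples do not constrain the lowpass component.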

Load dictionary.

D = util.convdicts()['RGB:8x8x3x64']

Set up admm.cbpdn.ConvBPDN options.

lmbda = 2e-2
opt = cbpdn.ConvBPDN.Options({'Verbose': True, 'MaxMainIter': 200,
                    'HighMemSolve': True, 'RelStopTol': 5e-3,
                    'AuxVarObj': False, 'RelaxParam': 1.8,
                    'rho': 5e1*lmbda + 1e-1, 'AutoRho': {'Enabled': False,
                    'StdResiduals': False}})

Construct an admm.cbpdn.AddMaskSim wrapper for admm.cbpdn.ConvBPDN and solve via the wrapper. This example could also have used admm.cbpdn.ConvBPDNMaskDcpl, which has similar performance in this application, but admm.cbpdn.AddMaskSim is more flexible because the wrapper can be applied to a variety of CSC solver objects.
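
The idea behind simulating a mask additively can be illustrated with a small NumPy identity. This is a sketch under an assumption about the technique in [50], not SPORCO's implementation: an extra impulse-filter coefficient map z is left free where the mask is zero and forced to zero where the mask is one, so that it absorbs the residual at masked-out samples. All names below are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 32
s = rng.standard_normal(n)                       # observed signal
W = (rng.uniform(size=n) >= 0.5).astype(float)   # 1 = observed, 0 = masked out
Dx = rng.standard_normal(n)                      # stand-in for a reconstruction D x

# Impulse coefficient map: zero where W == 1, free where W == 0.
# The residual-minimising choice cancels the error at masked-out samples.
z = (1.0 - W) * (s - Dx)

unweighted_fit = np.sum((Dx + z - s) ** 2)   # plain l2 term with the extra map
masked_fit = np.sum((W * (Dx - s)) ** 2)     # l2 term weighted by the mask

print(np.isclose(unweighted_fit, masked_fit))
```

The two data fidelity values agree, which suggests why a wrapper of this kind can add mask support to a solver that has no mask term of its own.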

ams = cbpdn.AddMaskSim(cbpdn.ConvBPDN, D, sh, mskp, lmbda, opt=opt)
X = ams.solve()
Itn   Fnc       DFid      Regℓ1     r         s
------------------------------------------------------
   0  3.61e+01  2.40e+00  1.69e+03  9.50e-01  6.10e-01
   1  3.42e+01  4.54e+00  1.48e+03  3.68e-01  5.88e-01
   2  2.96e+01  2.83e+00  1.34e+03  2.11e-01  2.91e-01
   3  2.73e+01  2.39e+00  1.24e+03  1.58e-01  1.98e-01
   4  2.61e+01  2.23e+00  1.19e+03  1.32e-01  1.50e-01
   5  2.53e+01  2.18e+00  1.16e+03  1.15e-01  1.21e-01
   6  2.43e+01  2.18e+00  1.11e+03  1.02e-01  1.04e-01
   7  2.35e+01  2.21e+00  1.06e+03  9.21e-02  9.08e-02
   8  2.27e+01  2.26e+00  1.02e+03  8.40e-02  8.07e-02
   9  2.22e+01  2.31e+00  9.93e+02  7.69e-02  7.22e-02
  10  2.17e+01  2.37e+00  9.67e+02  7.10e-02  6.53e-02
  11  2.13e+01  2.42e+00  9.42e+02  6.58e-02  5.96e-02
  12  2.08e+01  2.46e+00  9.15e+02  6.11e-02  5.48e-02
  13  2.02e+01  2.51e+00  8.87e+02  5.70e-02  5.06e-02
  14  1.97e+01  2.55e+00  8.59e+02  5.33e-02  4.68e-02
  15  1.92e+01  2.59e+00  8.32e+02  4.99e-02  4.37e-02
  16  1.88e+01  2.62e+00  8.07e+02  4.69e-02  4.10e-02
  17  1.83e+01  2.66e+00  7.84e+02  4.42e-02  3.84e-02
  18  1.80e+01  2.69e+00  7.64e+02  4.18e-02  3.59e-02
  19  1.76e+01  2.72e+00  7.45e+02  3.96e-02  3.38e-02
  20  1.73e+01  2.74e+00  7.27e+02  3.75e-02  3.18e-02
  21  1.70e+01  2.77e+00  7.11e+02  3.56e-02  2.99e-02
  22  1.67e+01  2.79e+00  6.96e+02  3.39e-02  2.82e-02
  23  1.65e+01  2.81e+00  6.83e+02  3.22e-02  2.66e-02
  24  1.62e+01  2.83e+00  6.70e+02  3.07e-02  2.52e-02
  25  1.60e+01  2.84e+00  6.58e+02  2.93e-02  2.38e-02
  26  1.58e+01  2.86e+00  6.47e+02  2.79e-02  2.27e-02
  27  1.56e+01  2.87e+00  6.36e+02  2.67e-02  2.16e-02
  28  1.54e+01  2.88e+00  6.25e+02  2.55e-02  2.06e-02
  29  1.52e+01  2.89e+00  6.16e+02  2.44e-02  1.97e-02
  30  1.50e+01  2.90e+00  6.06e+02  2.34e-02  1.89e-02
  31  1.49e+01  2.91e+00  5.98e+02  2.24e-02  1.81e-02
  32  1.47e+01  2.92e+00  5.91e+02  2.15e-02  1.73e-02
  33  1.46e+01  2.93e+00  5.84e+02  2.07e-02  1.65e-02
  34  1.45e+01  2.94e+00  5.78e+02  1.99e-02  1.57e-02
  35  1.44e+01  2.94e+00  5.72e+02  1.91e-02  1.51e-02
  36  1.42e+01  2.95e+00  5.65e+02  1.84e-02  1.45e-02
  37  1.41e+01  2.95e+00  5.59e+02  1.77e-02  1.39e-02
  38  1.40e+01  2.96e+00  5.52e+02  1.71e-02  1.34e-02
  39  1.39e+01  2.96e+00  5.47e+02  1.65e-02  1.29e-02
  40  1.38e+01  2.97e+00  5.41e+02  1.59e-02  1.25e-02
  41  1.37e+01  2.97e+00  5.36e+02  1.54e-02  1.20e-02
  42  1.36e+01  2.97e+00  5.32e+02  1.49e-02  1.16e-02
  43  1.35e+01  2.98e+00  5.27e+02  1.44e-02  1.12e-02
  44  1.34e+01  2.98e+00  5.23e+02  1.39e-02  1.09e-02
  45  1.34e+01  2.98e+00  5.19e+02  1.35e-02  1.05e-02
  46  1.33e+01  2.99e+00  5.15e+02  1.30e-02  1.02e-02
  47  1.32e+01  2.99e+00  5.11e+02  1.26e-02  9.88e-03
  48  1.31e+01  2.99e+00  5.07e+02  1.23e-02  9.56e-03
  49  1.31e+01  3.00e+00  5.04e+02  1.19e-02  9.23e-03
  50  1.30e+01  3.00e+00  5.01e+02  1.15e-02  8.93e-03
  51  1.30e+01  3.00e+00  4.98e+02  1.12e-02  8.65e-03
  52  1.29e+01  3.00e+00  4.95e+02  1.09e-02  8.37e-03
  53  1.28e+01  3.01e+00  4.92e+02  1.06e-02  8.12e-03
  54  1.28e+01  3.01e+00  4.89e+02  1.03e-02  7.86e-03
  55  1.27e+01  3.01e+00  4.86e+02  9.99e-03  7.61e-03
  56  1.27e+01  3.01e+00  4.83e+02  9.72e-03  7.37e-03
  57  1.26e+01  3.02e+00  4.81e+02  9.46e-03  7.15e-03
  58  1.26e+01  3.02e+00  4.79e+02  9.21e-03  6.94e-03
  59  1.25e+01  3.02e+00  4.76e+02  8.96e-03  6.74e-03
  60  1.25e+01  3.02e+00  4.74e+02  8.73e-03  6.56e-03
  61  1.25e+01  3.02e+00  4.72e+02  8.51e-03  6.38e-03
  62  1.24e+01  3.03e+00  4.70e+02  8.29e-03  6.22e-03
  63  1.24e+01  3.03e+00  4.68e+02  8.08e-03  6.05e-03
  64  1.23e+01  3.03e+00  4.66e+02  7.88e-03  5.89e-03
  65  1.23e+01  3.03e+00  4.64e+02  7.69e-03  5.75e-03
  66  1.23e+01  3.03e+00  4.62e+02  7.51e-03  5.62e-03
  67  1.22e+01  3.03e+00  4.60e+02  7.33e-03  5.49e-03
  68  1.22e+01  3.03e+00  4.58e+02  7.15e-03  5.37e-03
  69  1.21e+01  3.04e+00  4.56e+02  6.98e-03  5.25e-03
  70  1.21e+01  3.04e+00  4.54e+02  6.82e-03  5.13e-03
  71  1.21e+01  3.04e+00  4.52e+02  6.67e-03  5.01e-03
  72  1.21e+01  3.04e+00  4.51e+02  6.52e-03  4.90e-03
  73  1.20e+01  3.04e+00  4.49e+02  6.38e-03  4.79e-03
  74  1.20e+01  3.04e+00  4.48e+02  6.24e-03  4.69e-03
  75  1.20e+01  3.04e+00  4.46e+02  6.10e-03  4.60e-03
  76  1.19e+01  3.04e+00  4.45e+02  5.97e-03  4.50e-03
  77  1.19e+01  3.04e+00  4.44e+02  5.84e-03  4.42e-03
  78  1.19e+01  3.05e+00  4.42e+02  5.72e-03  4.33e-03
  79  1.19e+01  3.05e+00  4.41e+02  5.60e-03  4.25e-03
  80  1.18e+01  3.05e+00  4.40e+02  5.49e-03  4.17e-03
  81  1.18e+01  3.05e+00  4.38e+02  5.37e-03  4.10e-03
  82  1.18e+01  3.05e+00  4.37e+02  5.27e-03  4.02e-03
  83  1.18e+01  3.05e+00  4.36e+02  5.17e-03  3.94e-03
  84  1.18e+01  3.05e+00  4.35e+02  5.07e-03  3.86e-03
  85  1.17e+01  3.05e+00  4.34e+02  4.97e-03  3.78e-03
------------------------------------------------------
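
The log can be sanity-checked against the options set earlier. The sketch below copies the last three rows of the table and assumes two things that the source does not state explicitly: that Fnc is DFid plus lmbda times Regℓ1, and that the solver stops once both the primal residual r and the dual residual s fall below RelStopTol.

```python
lmbda = 2e-2          # regularisation parameter used for ConvBPDN above
rel_stop_tol = 5e-3   # 'RelStopTol' option used above

# (iteration, Fnc, DFid, Regl1, r, s) copied from the last rows of the log
rows = [
    (83, 1.18e+01, 3.05e+00, 4.36e+02, 5.17e-03, 3.94e-03),
    (84, 1.18e+01, 3.05e+00, 4.35e+02, 5.07e-03, 3.86e-03),
    (85, 1.17e+01, 3.05e+00, 4.34e+02, 4.97e-03, 3.78e-03),
]


def converged(r, s, tol):
    """Assumed stopping rule: both residuals below the tolerance."""
    return r < tol and s < tol


# First iteration in the excerpt at which the assumed rule is satisfied
stop_iter = next(k for k, _, _, _, r, s in rows if converged(r, s, rel_stop_tol))

# Functional value recomputed from its two terms (log values are rounded)
fnc_recomputed = [dfid + lmbda * reg for _, _, dfid, reg, _, _ in rows]

print("stopping iteration:", stop_iter)
print("recomputed Fnc:", ["%.2f" % f for f in fnc_recomputed])
```

Under these assumptions the solver terminates at the last logged iteration because the primal residual only then drops below the tolerance, well before the MaxMainIter limit of 200.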

Reconstruct from representation.

imgr = crop(sl + ams.reconstruct().squeeze())
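
The reconstruction of the highpass component is a sum of convolutions of dictionary filters with their coefficient maps. The one-dimensional sketch below computes such a sum in the frequency domain and compares it with a direct circular convolution. It assumes circular boundary conditions and uses hypothetical random data; it is not the SPORCO code path.

```python
import numpy as np

rng = np.random.default_rng(0)
N, M, L = 16, 3, 4   # signal length, number of filters, filter length

D = rng.standard_normal((L, M))                          # filters as columns
X = rng.standard_normal((N, M)) * (rng.uniform(size=(N, M)) < 0.2)  # sparse maps

# Frequency domain: zero-pad filters to length N, multiply, sum over filters
Df = np.fft.rfft(D, n=N, axis=0)
Xf = np.fft.rfft(X, axis=0)
recon_fft = np.fft.irfft(np.sum(Df * Xf, axis=1), n=N)

# Direct circular convolution, summed over filters, for comparison
recon_direct = np.zeros(N)
for m in range(M):
    for n in range(N):
        for l in range(L):
            recon_direct[n] += D[l, m] * X[(n - l) % N, m]

print(np.allclose(recon_fft, recon_direct))
```

Circular convolution wraps content around the edges, which is one reason for padding the image before coding and cropping the result afterwards.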

Display solve time and reconstruction performance.

print("AddMaskSim wrapped ConvBPDN solve time: %.2fs" %
      ams.timer.elapsed('solve'))
print("Corrupted image PSNR: %5.2f dB" % psnr(img, imgw))
print("Recovered image PSNR: %5.2f dB" % psnr(img, imgr))
AddMaskSim wrapped ConvBPDN solve time: 49.15s
Corrupted image PSNR: 10.57 dB
Recovered image PSNR: 29.37 dB
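
The PSNR values above compare each image with the reference. A minimal sketch of such a metric follows; psnr_sketch is a hypothetical helper, and the default peak value (the dynamic range of the reference) is an assumption that may differ from sporco.metric.psnr.

```python
import numpy as np


def psnr_sketch(ref, test, rng=None):
    """PSNR in dB; rng defaults to the dynamic range of ref (an assumption)."""
    ref = np.asarray(ref, dtype=np.float64)
    test = np.asarray(test, dtype=np.float64)
    if rng is None:
        rng = ref.max() - ref.min()
    mse = np.mean((ref - test) ** 2)   # mean squared error
    return 10.0 * np.log10(rng ** 2 / mse)


ref = np.array([0.0, 1.0])
noisy = np.array([0.0, 0.9])
value = psnr_sketch(ref, noisy)   # MSE is 0.005, so the ratio is 200
print("%.2f dB" % value)
```

An increase of 10 dB corresponds to a tenfold reduction in mean squared error, so the gain from 10.57 dB to 29.37 dB reported above is a reduction in error by a factor of roughly 76.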

Display reference, test, and reconstructed images.

fig = plot.figure(figsize=(21, 7))
plot.subplot(1, 3, 1)
plot.imview(img, title='Reference image', fig=fig)
plot.subplot(1, 3, 2)
plot.imview(imgw, title='Corrupted image', fig=fig)
plot.subplot(1, 3, 3)
plot.imview(imgr, title='Reconstructed image', fig=fig)
fig.show()
[Figure: reference, corrupted, and reconstructed images]

Display lowpass component and sparse representation.

fig = plot.figure(figsize=(14, 7))
plot.subplot(1, 2, 1)
plot.imview(sl, cmap=plot.cm.Blues, title='Lowpass component', fig=fig)
plot.subplot(1, 2, 2)
plot.imview(np.squeeze(np.sum(abs(X), axis=ams.cri.axisM)),
            cmap=plot.cm.Blues, title='Sparse representation', fig=fig)
fig.show()
[Figure: lowpass component and sparse representation]

Plot functional value, residuals, and rho.

its = ams.getitstat()
fig = plot.figure(figsize=(21, 7))
plot.subplot(1, 3, 1)
plot.plot(its.ObjFun, xlbl='Iterations', ylbl='Functional', fig=fig)
plot.subplot(1, 3, 2)
plot.plot(np.vstack((its.PrimalRsdl, its.DualRsdl)).T, ptyp='semilogy',
          xlbl='Iterations', ylbl='Residual', lgnd=['Primal', 'Dual'],
          fig=fig)
plot.subplot(1, 3, 3)
plot.plot(its.Rho, xlbl='Iterations', ylbl='Penalty Parameter', fig=fig)
fig.show()
[Figure: functional value, primal and dual residuals, and penalty parameter against iteration number]