CSC with a Spatial Mask
This example demonstrates the use of cbpdn.AddMaskSim for convolutional sparse coding with a spatial mask [50]. The example problem is inpainting of randomly distributed corruption of a colour image [51].
from __future__ import print_function
from builtins import input
import pyfftw # See https://github.com/pyFFTW/pyFFTW/issues/40
import numpy as np
from sporco import util
from sporco import signal
from sporco import plot
plot.config_notebook_plotting()
from sporco.admm import tvl2
from sporco.admm import cbpdn
from sporco.metric import psnr
Load a reference image.
img = util.ExampleImages().image('monarch.png', zoom=0.5, scaled=True,
                                 idxexp=np.s_[:, 160:672])
Create a random mask and apply it to the reference image to obtain the test image. (The call to numpy.random.seed ensures that the pseudo-random mask is reproducible.)
np.random.seed(12345)
frc = 0.5
msk = signal.rndmask(img.shape, frc, dtype=np.float32)
imgw = msk * img
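The masking step can be illustrated without SPORCO. The following is a minimal NumPy sketch, not the library's implementation: `random_mask` is a hypothetical helper, and it assumes that the fraction argument is the share of samples set to zero. As in the example, the mask has the full image shape, so each colour channel is masked independently.

```python
import numpy as np

def random_mask(shape, frc, seed=0):
    """Hypothetical stand-in: 0/1 mask with roughly a fraction frc of zeros."""
    rng = np.random.default_rng(seed)
    # Samples drawn below frc become 0 (corrupted), the rest 1 (known)
    return (rng.random(shape) >= frc).astype(np.float32)

img = np.full((64, 64, 3), 0.5, dtype=np.float32)  # placeholder image
msk = random_mask(img.shape, 0.5)
imgw = msk * img                  # corrupted samples are set to zero
zero_fraction = 1.0 - msk.mean()  # close to frc for a large image
```

With frc = 0.5, about half of all samples are lost, which is why the corrupted image has such a low PSNR.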
Define pad and crop functions.
pn = 8
spad = lambda x: np.pad(x, ((pn, pn), (pn, pn), (0, 0)), mode='symmetric')
zpad = lambda x: np.pad(x, ((pn, pn), (pn, pn), (0, 0)), mode='constant')
crop = lambda x: x[pn:-pn, pn:-pn]
Construct padded mask and test image.
mskp = zpad(msk)
imgwp = spad(imgw)
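The behaviour of these helpers can be checked on a small array. The sketch below reuses the same definitions with a smaller, illustrative pad width. It shows that cropping undoes either kind of padding, that symmetric padding mirrors the edge rows, and that zero padding leaves the border at zero. Applied to the mask, the zero border presumably marks the padded region as unknown, so that the solver is not forced to fit the symmetric extension of the image.

```python
import numpy as np

pn = 2  # small pad width for illustration (the example uses 8)
spad = lambda x: np.pad(x, ((pn, pn), (pn, pn), (0, 0)), mode='symmetric')
zpad = lambda x: np.pad(x, ((pn, pn), (pn, pn), (0, 0)), mode='constant')
crop = lambda x: x[pn:-pn, pn:-pn]

# Test array with no zero entries, shape (rows, cols, channels)
x = np.arange(1, 4 * 5 * 3 + 1, dtype=np.float64).reshape(4, 5, 3)
xs, xz = spad(x), zpad(x)

print(xs.shape)                      # spatial axes grow by 2 * pn
print(np.array_equal(crop(xs), x))   # cropping recovers the original
```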
Apply \(\ell_2\)-TV denoising with a spatial mask as a non-linear lowpass filter. The highpass component is the difference between the test image and the lowpass component, multiplied by the mask for faster convergence of the convolutional sparse coding (see [60]).
lmbda = 0.05
opt = tvl2.TVL2Denoise.Options({'Verbose': False, 'MaxMainIter': 200,
                                'DFidWeight': mskp, 'gEvalY': False,
                                'AutoRho': {'Enabled': True}})
b = tvl2.TVL2Denoise(imgwp, lmbda, opt, caxis=2)
sl = b.solve()
sh = mskp * (imgwp - sl)
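As a hedged summary of the step above, the masked \(\ell_2\)-TV problem can be written as follows, where \(W\) is the padded mask passed as DFidWeight, \(\mathbf{s}\) is the padded test image, and \(G_r\), \(G_c\) are row and column gradient operators. The notation is introduced here for illustration, and the coupling of colour channels selected by caxis is omitted for brevity.

```latex
\mathbf{s}_l = \operatorname*{argmin}_{\mathbf{x}} \;
  \frac{1}{2} \left\| W \odot (\mathbf{x} - \mathbf{s}) \right\|_2^2
  + \lambda \left\| \sqrt{(G_r \mathbf{x})^2 + (G_c \mathbf{x})^2} \right\|_1 ,
\qquad
\mathbf{s}_h = W \odot (\mathbf{s} - \mathbf{s}_l)
```

Because the data fidelity term is weighted by \(W\), samples where the mask is zero do not constrain the lowpass component, and the total variation term fills them in from their neighbours.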
Load dictionary.
D = util.convdicts()['RGB:8x8x3x64']
Set up admm.cbpdn.ConvBPDN options.
lmbda = 2e-2
opt = cbpdn.ConvBPDN.Options({'Verbose': True, 'MaxMainIter': 200,
                              'HighMemSolve': True, 'RelStopTol': 5e-3,
                              'AuxVarObj': False, 'RelaxParam': 1.8,
                              'rho': 5e1*lmbda + 1e-1,
                              'AutoRho': {'Enabled': False,
                                          'StdResiduals': False}})
Construct admm.cbpdn.AddMaskSim wrapper for admm.cbpdn.ConvBPDN and solve via wrapper. This example could also have made use of admm.cbpdn.ConvBPDNMaskDcpl, which has similar performance in this application, but admm.cbpdn.AddMaskSim has the advantage of greater flexibility in that the wrapper can be applied to a variety of CSC solver objects.
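For orientation, the masked convolutional sparse coding problem that both approaches address can be sketched as below. The symbols are introduced here for illustration: \(\mathbf{d}_m\) are the dictionary filters, \(\mathbf{x}_m\) the coefficient maps, \(\mathbf{s}\) the highpass component, and \(W\) the padded mask. Channel indices are suppressed.

```latex
\operatorname*{argmin}_{\{\mathbf{x}_m\}} \;
  \frac{1}{2} \left\| W \odot \Big( \sum_m \mathbf{d}_m * \mathbf{x}_m
  - \mathbf{s} \Big) \right\|_2^2
  + \lambda \sum_m \left\| \mathbf{x}_m \right\|_1
```

To the best of our understanding, the additive mask simulation approach does not modify the solver itself. Instead it appends an extra impulse atom to the dictionary, whose coefficients are free to absorb the residual at masked-out samples. This is why the wrapper can be combined with different CSC solver classes.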
ams = cbpdn.AddMaskSim(cbpdn.ConvBPDN, D, sh, mskp, lmbda, opt=opt)
X = ams.solve()
Itn Fnc DFid Regℓ1 r s
------------------------------------------------------
0 3.61e+01 2.40e+00 1.69e+03 9.50e-01 6.10e-01
1 3.42e+01 4.54e+00 1.48e+03 3.68e-01 5.88e-01
2 2.96e+01 2.83e+00 1.34e+03 2.11e-01 2.91e-01
3 2.73e+01 2.39e+00 1.24e+03 1.58e-01 1.98e-01
4 2.61e+01 2.23e+00 1.19e+03 1.32e-01 1.50e-01
5 2.53e+01 2.18e+00 1.16e+03 1.15e-01 1.21e-01
6 2.43e+01 2.18e+00 1.11e+03 1.02e-01 1.04e-01
7 2.35e+01 2.21e+00 1.06e+03 9.21e-02 9.08e-02
8 2.27e+01 2.26e+00 1.02e+03 8.40e-02 8.07e-02
9 2.22e+01 2.31e+00 9.93e+02 7.69e-02 7.22e-02
10 2.17e+01 2.37e+00 9.67e+02 7.10e-02 6.53e-02
11 2.13e+01 2.42e+00 9.42e+02 6.58e-02 5.96e-02
12 2.08e+01 2.46e+00 9.15e+02 6.11e-02 5.48e-02
13 2.02e+01 2.51e+00 8.87e+02 5.70e-02 5.06e-02
14 1.97e+01 2.55e+00 8.59e+02 5.33e-02 4.68e-02
15 1.92e+01 2.59e+00 8.32e+02 4.99e-02 4.37e-02
16 1.88e+01 2.62e+00 8.07e+02 4.69e-02 4.10e-02
17 1.83e+01 2.66e+00 7.84e+02 4.42e-02 3.84e-02
18 1.80e+01 2.69e+00 7.64e+02 4.18e-02 3.59e-02
19 1.76e+01 2.72e+00 7.45e+02 3.96e-02 3.38e-02
20 1.73e+01 2.74e+00 7.27e+02 3.75e-02 3.18e-02
21 1.70e+01 2.77e+00 7.11e+02 3.56e-02 2.99e-02
22 1.67e+01 2.79e+00 6.96e+02 3.39e-02 2.82e-02
23 1.65e+01 2.81e+00 6.83e+02 3.22e-02 2.66e-02
24 1.62e+01 2.83e+00 6.70e+02 3.07e-02 2.52e-02
25 1.60e+01 2.84e+00 6.58e+02 2.93e-02 2.38e-02
26 1.58e+01 2.86e+00 6.47e+02 2.79e-02 2.27e-02
27 1.56e+01 2.87e+00 6.36e+02 2.67e-02 2.16e-02
28 1.54e+01 2.88e+00 6.25e+02 2.55e-02 2.06e-02
29 1.52e+01 2.89e+00 6.16e+02 2.44e-02 1.97e-02
30 1.50e+01 2.90e+00 6.06e+02 2.34e-02 1.89e-02
31 1.49e+01 2.91e+00 5.98e+02 2.24e-02 1.81e-02
32 1.47e+01 2.92e+00 5.91e+02 2.15e-02 1.73e-02
33 1.46e+01 2.93e+00 5.84e+02 2.07e-02 1.65e-02
34 1.45e+01 2.94e+00 5.78e+02 1.99e-02 1.57e-02
35 1.44e+01 2.94e+00 5.72e+02 1.91e-02 1.51e-02
36 1.42e+01 2.95e+00 5.65e+02 1.84e-02 1.45e-02
37 1.41e+01 2.95e+00 5.59e+02 1.77e-02 1.39e-02
38 1.40e+01 2.96e+00 5.52e+02 1.71e-02 1.34e-02
39 1.39e+01 2.96e+00 5.47e+02 1.65e-02 1.29e-02
40 1.38e+01 2.97e+00 5.41e+02 1.59e-02 1.25e-02
41 1.37e+01 2.97e+00 5.36e+02 1.54e-02 1.20e-02
42 1.36e+01 2.97e+00 5.32e+02 1.49e-02 1.16e-02
43 1.35e+01 2.98e+00 5.27e+02 1.44e-02 1.12e-02
44 1.34e+01 2.98e+00 5.23e+02 1.39e-02 1.09e-02
45 1.34e+01 2.98e+00 5.19e+02 1.35e-02 1.05e-02
46 1.33e+01 2.99e+00 5.15e+02 1.30e-02 1.02e-02
47 1.32e+01 2.99e+00 5.11e+02 1.26e-02 9.88e-03
48 1.31e+01 2.99e+00 5.07e+02 1.23e-02 9.56e-03
49 1.31e+01 3.00e+00 5.04e+02 1.19e-02 9.23e-03
50 1.30e+01 3.00e+00 5.01e+02 1.15e-02 8.93e-03
51 1.30e+01 3.00e+00 4.98e+02 1.12e-02 8.65e-03
52 1.29e+01 3.00e+00 4.95e+02 1.09e-02 8.37e-03
53 1.28e+01 3.01e+00 4.92e+02 1.06e-02 8.12e-03
54 1.28e+01 3.01e+00 4.89e+02 1.03e-02 7.86e-03
55 1.27e+01 3.01e+00 4.86e+02 9.99e-03 7.61e-03
56 1.27e+01 3.01e+00 4.83e+02 9.72e-03 7.37e-03
57 1.26e+01 3.02e+00 4.81e+02 9.46e-03 7.15e-03
58 1.26e+01 3.02e+00 4.79e+02 9.21e-03 6.94e-03
59 1.25e+01 3.02e+00 4.76e+02 8.96e-03 6.74e-03
60 1.25e+01 3.02e+00 4.74e+02 8.73e-03 6.56e-03
61 1.25e+01 3.02e+00 4.72e+02 8.51e-03 6.38e-03
62 1.24e+01 3.03e+00 4.70e+02 8.29e-03 6.22e-03
63 1.24e+01 3.03e+00 4.68e+02 8.08e-03 6.05e-03
64 1.23e+01 3.03e+00 4.66e+02 7.88e-03 5.89e-03
65 1.23e+01 3.03e+00 4.64e+02 7.69e-03 5.75e-03
66 1.23e+01 3.03e+00 4.62e+02 7.51e-03 5.62e-03
67 1.22e+01 3.03e+00 4.60e+02 7.33e-03 5.49e-03
68 1.22e+01 3.03e+00 4.58e+02 7.15e-03 5.37e-03
69 1.21e+01 3.04e+00 4.56e+02 6.98e-03 5.25e-03
70 1.21e+01 3.04e+00 4.54e+02 6.82e-03 5.13e-03
71 1.21e+01 3.04e+00 4.52e+02 6.67e-03 5.01e-03
72 1.21e+01 3.04e+00 4.51e+02 6.52e-03 4.90e-03
73 1.20e+01 3.04e+00 4.49e+02 6.38e-03 4.79e-03
74 1.20e+01 3.04e+00 4.48e+02 6.24e-03 4.69e-03
75 1.20e+01 3.04e+00 4.46e+02 6.10e-03 4.60e-03
76 1.19e+01 3.04e+00 4.45e+02 5.97e-03 4.50e-03
77 1.19e+01 3.04e+00 4.44e+02 5.84e-03 4.42e-03
78 1.19e+01 3.05e+00 4.42e+02 5.72e-03 4.33e-03
79 1.19e+01 3.05e+00 4.41e+02 5.60e-03 4.25e-03
80 1.18e+01 3.05e+00 4.40e+02 5.49e-03 4.17e-03
81 1.18e+01 3.05e+00 4.38e+02 5.37e-03 4.10e-03
82 1.18e+01 3.05e+00 4.37e+02 5.27e-03 4.02e-03
83 1.18e+01 3.05e+00 4.36e+02 5.17e-03 3.94e-03
84 1.18e+01 3.05e+00 4.35e+02 5.07e-03 3.86e-03
85 1.17e+01 3.05e+00 4.34e+02 4.97e-03 3.78e-03
------------------------------------------------------
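The log ends at iteration 85, well short of the 200 allowed by MaxMainIter. The check below is a small sketch of why, assuming that the solver stops once both normalised residuals fall below RelStopTol (with no absolute tolerance in play). The residual values are copied from the last rows of the log; `converged` is a hypothetical helper.

```python
rel_stop_tol = 5e-3  # 'RelStopTol' from the solver options

# (iteration, primal residual r, dual residual s) copied from the log
rows = [(83, 5.17e-03, 3.94e-03),
        (84, 5.07e-03, 3.86e-03),
        (85, 4.97e-03, 3.78e-03)]

def converged(r, s, tol):
    # Assumed rule: both normalised residuals must be below the tolerance
    return r < tol and s < tol

first_stop = next(i for i, r, s in rows if converged(r, s, rel_stop_tol))
print(first_stop)  # 85
```

The dual residual is already below the tolerance by iteration 47, so in this run the primal residual determines when the iterations stop.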
Reconstruct from representation.
imgr = crop(sl + ams.reconstruct().squeeze())
Display solve time and reconstruction performance.
print("AddMaskSim wrapped ConvBPDN solve time: %.2fs" %
      ams.timer.elapsed('solve'))
print("Corrupted image PSNR: %5.2f dB" % psnr(img, imgw))
print("Recovered image PSNR: %5.2f dB" % psnr(img, imgr))
AddMaskSim wrapped ConvBPDN solve time: 49.15s
Corrupted image PSNR: 10.57 dB
Recovered image PSNR: 29.37 dB
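For readers unfamiliar with the metric, PSNR can be sketched in a few lines of NumPy. `psnr_sketch` is a hypothetical helper, not the sporco.metric.psnr implementation, and it assumes that the peak value defaults to the dynamic range of the reference image.

```python
import numpy as np

def psnr_sketch(ref, cmp, rng=None):
    """Hypothetical helper: PSNR in dB of cmp relative to ref."""
    ref = np.asarray(ref, dtype=np.float64)
    cmp = np.asarray(cmp, dtype=np.float64)
    if rng is None:
        # Assumption: peak value taken as the dynamic range of the reference
        rng = ref.max() - ref.min()
    mse = np.mean((ref - cmp) ** 2)
    return 10.0 * np.log10(rng ** 2 / mse)

ref = np.array([0.0, 1.0])
val = psnr_sketch(ref, ref + 0.1)  # uniform error of 0.1 on a unit range
```

On the figures reported above, the reconstruction improves on the corrupted image by 29.37 - 10.57 = 18.80 dB.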
Display reference, test, and reconstructed images.
fig = plot.figure(figsize=(21, 7))
plot.subplot(1, 3, 1)
plot.imview(img, title='Reference image', fig=fig)
plot.subplot(1, 3, 2)
plot.imview(imgw, title='Corrupted image', fig=fig)
plot.subplot(1, 3, 3)
plot.imview(imgr, title='Reconstructed image', fig=fig)
fig.show()
Display lowpass component and sparse representation.
fig = plot.figure(figsize=(14, 7))
plot.subplot(1, 2, 1)
plot.imview(sl, cmap=plot.cm.Blues, title='Lowpass component', fig=fig)
plot.subplot(1, 2, 2)
plot.imview(np.squeeze(np.sum(abs(X), axis=ams.cri.axisM)),
            cmap=plot.cm.Blues, title='Sparse representation', fig=fig)
fig.show()
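The sparse representation panel collapses the coefficient array to a single activity map by summing absolute values over the filter axis. The sketch below reproduces that reduction on random data. The array layout (rows, columns, channels, signals, filters) and the position of the filter axis are assumptions made for illustration, not values read from the solver.

```python
import numpy as np

# Assumed layout: (rows, cols, channels, signals, filters)
rows, cols, M = 6, 5, 4
rng = np.random.default_rng(0)
X = rng.standard_normal((rows, cols, 1, 1, M))
axisM = X.ndim - 1  # filter axis in this assumed layout

# Sum of coefficient magnitudes across all filters at each pixel
activity = np.squeeze(np.sum(np.abs(X), axis=axisM))
print(activity.shape)
```

Large values in the map indicate pixels where many filters, or a few strongly weighted ones, contribute to the reconstruction.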
Plot functional value, residuals, and rho.
its = ams.getitstat()
fig = plot.figure(figsize=(21, 7))
plot.subplot(1, 3, 1)
plot.plot(its.ObjFun, xlbl='Iterations', ylbl='Functional', fig=fig)
plot.subplot(1, 3, 2)
plot.plot(np.vstack((its.PrimalRsdl, its.DualRsdl)).T, ptyp='semilogy',
          xlbl='Iterations', ylbl='Residual', lgnd=['Primal', 'Dual'],
          fig=fig)
plot.subplot(1, 3, 3)
plot.plot(its.Rho, xlbl='Iterations', ylbl='Penalty Parameter', fig=fig)
fig.show()