CUDA Convolutional Sparse Coding with Gradient Term
This example demonstrates the use of the interface to the CUDA CSC solver extension package. It includes a test for the availability of a GPU: if no GPU is available, or if the extension package is not installed, the Python version of the ConvBPDNGradReg solver is run instead.
# Notebook setup: imports, plotting configuration for notebook display, and
# capture of output produced by compiled (CUDA) code.
from __future__ import print_function
from builtins import input
import pyfftw # See https://github.com/pyFFTW/pyFFTW/issues/40
import numpy as np
from sporco import util
from sporco import fft
from sporco import metric
from sporco import plot
# SPORCO helper; presumably configures plotting for inline notebook display.
plot.config_notebook_plotting()
from sporco import cuda
from sporco.admm import cbpdn
# If running in a notebook, try to use wurlitzer so that output from the CUDA
# code will be properly captured in the notebook.
# sys_pipes is later used as a context manager around the CUDA solver call.
sys_pipes = util.notebook_system_output()
Load example image.
# Scaled greyscale test image, cropped to a 512x512 region
# (rows 10-521, columns 100-611).
img = util.ExampleImages().image(
    'barbara.png', scaled=True, gray=True,
    idxexp=(slice(10, 522), slice(100, 612)))
Load the main dictionary and prepend an impulse filter for representing the lowpass component.
# Main dictionary (going by the key, presumably 36 filters of size 12x12).
Db = util.convdicts()['G:12x12x36']
# Impulse filter: a single unit sample at the origin, with the same spatial
# support as the filters in Db.
di = np.zeros(Db.shape[:2] + (1,), dtype=np.float32)
di[0, 0] = 1
# Stack the impulse first so that its coefficient map sits at index 0.
D = np.concatenate([di, Db], axis=2)
Set up weights for the \(\ell_1\) norm to disable regularization of the coefficient map corresponding to the impulse filter.
# One l1 weight per dictionary filter. The two leading singleton axes are
# presumably broadcast over the spatial axes of the coefficient maps.
wl1 = np.ones((1, 1) + D.shape[2:], dtype=np.float32)
# A zero weight leaves the impulse filter's map (index 0) unpenalized.
wl1[..., 0] = 0.0
Set up weights for the \(\ell_2\) norm of the gradient to disable regularization of all coefficient maps except for the one corresponding to the impulse filter.
# Gradient-regularization weights, one per filter: only the map of the
# impulse filter (index 0) receives a non-zero weight.
wgr = np.zeros(D.shape[2], dtype=np.float32)
wgr[0] = 1.0
Set up admm.cbpdn.ConvBPDNGradReg options.
lmbda = 0.01  # weight of the l1 sparsity term
mu = 0.5      # weight of the gradient regularization term
# ADMM options. Automatic rho adaptation is disabled, so rho stays fixed;
# the per-filter weight arrays defined above are passed in here.
opt = cbpdn.ConvBPDNGradReg.Options({
    'Verbose': True,
    'MaxMainIter': 250,
    'HighMemSolve': True,
    'RelStopTol': 5e-3,
    'AuxVarObj': False,
    'AutoRho': {'Enabled': False},
    'rho': 0.5,
    'L1Weight': wl1,
    'GradWeight': wgr,
})
If a GPU is available, run the CUDA ConvBPDNGradReg solver; otherwise, run the standard Python version.
# Use the CUDA solver when at least one GPU is reported; otherwise fall
# back to the pure-Python ADMM solver.
if cuda.device_count() > 0:
    print('%s GPU found: running CUDA solver' % cuda.device_name())
    tm = util.Timer()
    # Route compiled-code output through sys_pipes, and time the solve.
    with sys_pipes():
        with util.ContextTimer(tm):
            X = cuda.cbpdngrd(D, img, lmbda, mu, opt)
    t = tm.elapsed()
else:
    print('GPU not found: running Python solver')
    c = cbpdn.ConvBPDNGradReg(D, img, lmbda, mu, opt)
    # Drop singleton axes from the solution array.
    X = c.solve().squeeze()
    # The Python solver keeps its own timer; read the 'solve' entry.
    t = c.timer.elapsed('solve')
print('Solve time: %.2f s' % t)
GPU not found: running Python solver
Itn Fnc DFid Regℓ1 Regℓ2∇ r s
----------------------------------------------------------------
0 3.50e+03 3.02e+03 4.85e+04 3.62e-01 4.45e-01 9.48e+00
1 5.78e+02 1.27e+02 4.50e+04 1.88e+00 1.19e-01 2.28e+00
2 3.60e+02 9.87e+00 3.49e+04 3.06e+00 6.07e-02 7.70e-01
3 2.86e+02 4.87e+00 2.79e+04 4.19e+00 4.56e-02 5.12e-01
4 2.34e+02 3.36e+00 2.27e+04 5.51e+00 3.64e-02 4.01e-01
5 2.00e+02 2.92e+00 1.93e+04 6.84e+00 3.00e-02 3.33e-01
6 1.72e+02 2.79e+00 1.65e+04 8.18e+00 2.50e-02 2.85e-01
7 1.51e+02 2.73e+00 1.43e+04 9.54e+00 2.09e-02 2.53e-01
8 1.34e+02 2.67e+00 1.26e+04 1.09e+01 1.78e-02 2.24e-01
9 1.20e+02 2.62e+00 1.12e+04 1.22e+01 1.52e-02 2.01e-01
10 1.10e+02 2.57e+00 1.01e+04 1.35e+01 1.32e-02 1.80e-01
11 1.02e+02 2.53e+00 9.22e+03 1.47e+01 1.16e-02 1.60e-01
12 9.58e+01 2.51e+00 8.53e+03 1.59e+01 1.03e-02 1.43e-01
13 9.05e+01 2.49e+00 7.96e+03 1.69e+01 9.15e-03 1.29e-01
14 8.57e+01 2.49e+00 7.43e+03 1.78e+01 8.19e-03 1.18e-01
15 8.17e+01 2.50e+00 6.99e+03 1.87e+01 7.37e-03 1.09e-01
16 7.85e+01 2.50e+00 6.63e+03 1.94e+01 6.68e-03 1.00e-01
17 7.62e+01 2.50e+00 6.37e+03 2.00e+01 6.10e-03 9.21e-02
18 7.43e+01 2.50e+00 6.15e+03 2.06e+01 5.60e-03 8.43e-02
19 7.25e+01 2.50e+00 5.95e+03 2.10e+01 5.16e-03 7.72e-02
20 7.09e+01 2.51e+00 5.77e+03 2.14e+01 4.77e-03 7.11e-02
21 6.94e+01 2.52e+00 5.60e+03 2.18e+01 4.41e-03 6.61e-02
22 6.82e+01 2.53e+00 5.46e+03 2.21e+01 4.09e-03 6.18e-02
23 6.70e+01 2.54e+00 5.33e+03 2.24e+01 3.81e-03 5.77e-02
24 6.61e+01 2.55e+00 5.23e+03 2.26e+01 3.56e-03 5.36e-02
25 6.52e+01 2.56e+00 5.13e+03 2.28e+01 3.33e-03 4.98e-02
26 6.44e+01 2.57e+00 5.04e+03 2.30e+01 3.12e-03 4.64e-02
27 6.37e+01 2.57e+00 4.96e+03 2.31e+01 2.93e-03 4.34e-02
28 6.31e+01 2.58e+00 4.89e+03 2.32e+01 2.76e-03 4.06e-02
29 6.25e+01 2.59e+00 4.82e+03 2.33e+01 2.60e-03 3.82e-02
30 6.19e+01 2.60e+00 4.76e+03 2.34e+01 2.45e-03 3.60e-02
31 6.14e+01 2.60e+00 4.70e+03 2.35e+01 2.32e-03 3.42e-02
32 6.08e+01 2.61e+00 4.64e+03 2.36e+01 2.19e-03 3.27e-02
33 6.04e+01 2.61e+00 4.59e+03 2.36e+01 2.07e-03 3.12e-02
34 5.99e+01 2.62e+00 4.55e+03 2.37e+01 1.96e-03 2.95e-02
35 5.96e+01 2.62e+00 4.51e+03 2.37e+01 1.86e-03 2.78e-02
36 5.93e+01 2.63e+00 4.48e+03 2.38e+01 1.77e-03 2.61e-02
37 5.89e+01 2.63e+00 4.44e+03 2.38e+01 1.69e-03 2.47e-02
38 5.86e+01 2.63e+00 4.41e+03 2.38e+01 1.61e-03 2.34e-02
39 5.83e+01 2.64e+00 4.38e+03 2.39e+01 1.53e-03 2.23e-02
40 5.80e+01 2.64e+00 4.34e+03 2.39e+01 1.46e-03 2.13e-02
41 5.77e+01 2.64e+00 4.31e+03 2.39e+01 1.39e-03 2.04e-02
42 5.75e+01 2.65e+00 4.29e+03 2.39e+01 1.33e-03 1.95e-02
43 5.72e+01 2.65e+00 4.26e+03 2.40e+01 1.27e-03 1.87e-02
44 5.70e+01 2.65e+00 4.24e+03 2.40e+01 1.21e-03 1.78e-02
45 5.68e+01 2.65e+00 4.22e+03 2.40e+01 1.16e-03 1.70e-02
46 5.66e+01 2.66e+00 4.20e+03 2.40e+01 1.11e-03 1.62e-02
47 5.64e+01 2.66e+00 4.18e+03 2.40e+01 1.06e-03 1.55e-02
48 5.62e+01 2.66e+00 4.16e+03 2.40e+01 1.02e-03 1.49e-02
49 5.61e+01 2.66e+00 4.14e+03 2.40e+01 9.72e-04 1.43e-02
50 5.59e+01 2.66e+00 4.12e+03 2.40e+01 9.31e-04 1.37e-02
51 5.57e+01 2.67e+00 4.10e+03 2.41e+01 8.91e-04 1.32e-02
52 5.56e+01 2.67e+00 4.09e+03 2.41e+01 8.55e-04 1.27e-02
53 5.54e+01 2.67e+00 4.07e+03 2.41e+01 8.20e-04 1.21e-02
54 5.53e+01 2.67e+00 4.06e+03 2.41e+01 7.88e-04 1.16e-02
55 5.52e+01 2.67e+00 4.05e+03 2.41e+01 7.57e-04 1.11e-02
56 5.51e+01 2.67e+00 4.03e+03 2.41e+01 7.28e-04 1.06e-02
57 5.50e+01 2.68e+00 4.02e+03 2.41e+01 6.99e-04 1.02e-02
58 5.48e+01 2.68e+00 4.01e+03 2.41e+01 6.71e-04 9.83e-03
59 5.47e+01 2.68e+00 4.00e+03 2.41e+01 6.45e-04 9.46e-03
60 5.46e+01 2.68e+00 3.99e+03 2.41e+01 6.20e-04 9.11e-03
61 5.45e+01 2.68e+00 3.98e+03 2.41e+01 5.97e-04 8.76e-03
62 5.44e+01 2.68e+00 3.97e+03 2.41e+01 5.74e-04 8.42e-03
63 5.43e+01 2.68e+00 3.96e+03 2.41e+01 5.53e-04 8.11e-03
64 5.43e+01 2.68e+00 3.95e+03 2.41e+01 5.32e-04 7.83e-03
65 5.42e+01 2.68e+00 3.94e+03 2.41e+01 5.12e-04 7.56e-03
66 5.41e+01 2.68e+00 3.93e+03 2.41e+01 4.93e-04 7.29e-03
67 5.40e+01 2.68e+00 3.93e+03 2.41e+01 4.75e-04 7.03e-03
68 5.39e+01 2.69e+00 3.92e+03 2.41e+01 4.58e-04 6.77e-03
69 5.39e+01 2.69e+00 3.91e+03 2.41e+01 4.42e-04 6.50e-03
70 5.38e+01 2.69e+00 3.91e+03 2.41e+01 4.27e-04 6.26e-03
71 5.38e+01 2.69e+00 3.90e+03 2.41e+01 4.12e-04 6.04e-03
72 5.37e+01 2.69e+00 3.89e+03 2.41e+01 3.97e-04 5.84e-03
73 5.36e+01 2.69e+00 3.89e+03 2.42e+01 3.83e-04 5.65e-03
74 5.36e+01 2.69e+00 3.88e+03 2.42e+01 3.70e-04 5.47e-03
75 5.35e+01 2.69e+00 3.87e+03 2.42e+01 3.57e-04 5.29e-03
76 5.35e+01 2.69e+00 3.87e+03 2.42e+01 3.45e-04 5.11e-03
77 5.34e+01 2.69e+00 3.86e+03 2.42e+01 3.33e-04 4.95e-03
----------------------------------------------------------------
Solve time: 48.24 s
Reconstruct the image from the sparse representation.
# Convolve each filter with its coefficient map (over the two spatial
# axes), then sum the per-filter contributions along the filter axis.
imgr = fft.fftconv(D, X, axes=(0, 1)).sum(axis=2)
print("Reconstruction PSNR: %.2fdB\n" % metric.psnr(img, imgr))
Reconstruction PSNR: 45.50dB
Display representation and reconstructed image.
fig = plot.figure(figsize=(14, 14))
# Top row: the impulse-filter (lowpass) map and the summed magnitudes of
# all remaining coefficient maps.
plot.subplot(2, 2, 1)
plot.imview(X[..., 0].squeeze(), title='Lowpass component', fig=fig)
plot.subplot(2, 2, 2)
plot.imview(np.abs(X[..., 1:]).sum(axis=2).squeeze(), cmap=plot.cm.Blues,
            title='Main representation', fig=fig)
# Bottom row: the reconstruction and its difference from the input image.
plot.subplot(2, 2, 3)
plot.imview(imgr, title='Reconstructed image', fig=fig)
plot.subplot(2, 2, 4)
plot.imview(imgr - img, fltscl=True, title='Reconstruction difference',
            fig=fig)
fig.show()