CUDA Convolutional Sparse Coding with Gradient Term

This example demonstrates the use of the interface to the CUDA CSC solver extension package. It first tests whether a GPU is available, and if no GPU is available, or if the extension package is not installed, it runs the Python version of the ConvBPDNGradReg solver instead.

from __future__ import print_function
from builtins import input

import pyfftw   # See https://github.com/pyFFTW/pyFFTW/issues/40
import numpy as np

from sporco import util
from sporco import fft
from sporco import metric
from sporco import plot
plot.config_notebook_plotting()
from sporco import cuda
from sporco.admm import cbpdn

# If running in a notebook, try to use wurlitzer so that output from the CUDA
# code will be properly captured in the notebook. The returned object is used
# below as a context manager factory (`with sys_pipes(), ...`) around the CUDA
# solver call only; the Python solver prints through Python's own stdout.
# NOTE(review): presumably this yields a no-op context manager outside a
# notebook or when wurlitzer is unavailable -- confirm against sporco.util.
sys_pipes = util.notebook_system_output()

Load example image.

# Load a 512x512 crop (rows 10..521, cols 100..611) of the greyscale,
# scaled "barbara" test image.
img = util.ExampleImages().image(
    'barbara.png', idxexp=np.s_[10:522, 100:612], scaled=True, gray=True)

Load the main dictionary and prepend an impulse filter for representation of the lowpass component.

# Main dictionary; the 'G:12x12x36' key suggests 36 filters of size 12x12.
Db = util.convdicts()['G:12x12x36']
# Single impulse (delta) filter with the same spatial support as the
# dictionary filters; its coefficient map carries the lowpass component.
di = np.zeros(Db.shape[:2] + (1,), dtype=np.float32)
di[0, 0, 0] = 1.0
# Stack along the filter axis so the impulse filter is filter index 0.
D = np.dstack((di, Db))

Set up weights for the \(\ell_1\) norm to disable regularization of the coefficient map corresponding to the impulse filter.

# One l1 weight per dictionary filter, shaped (1, 1, K) so that it
# (presumably) broadcasts across the spatial axes of the coefficient maps.
wl1 = np.ones((1, 1, D.shape[2]), dtype=np.float32)
# Zero weight: do not l1-penalise the impulse-filter (lowpass) map.
wl1[:, :, 0] = 0

Set up weights for the \(\ell_2\) norm of the gradient to disable regularization of all coefficient maps except for the one corresponding to the impulse filter.

# Gradient-regularization weights: zero for every filter except the
# impulse filter at index 0.
wgr = np.zeros(D.shape[2], dtype=np.float32)
wgr[0] = 1

Set up admm.cbpdn.ConvBPDNGradReg options.

# Regularization parameters passed to the solver as (lmbda, mu): the l1 term
# and the gradient term respectively, per the ConvBPDNGradReg signature.
lmbda, mu = 1e-2, 5e-1
# Adaptive rho is disabled, so the ADMM penalty stays fixed at 0.5.
opt = cbpdn.ConvBPDNGradReg.Options({
    'Verbose': True,
    'MaxMainIter': 250,
    'HighMemSolve': True,
    'RelStopTol': 5e-3,
    'AuxVarObj': False,
    'AutoRho': {'Enabled': False},
    'rho': 0.5,
    'L1Weight': wl1,
    'GradWeight': wgr,
})

If a GPU is available, run the CUDA ConvBPDNGradReg solver; otherwise, run the standard Python version.

if cuda.device_count() <= 0:
    # No usable GPU (or CUDA extension missing): use the pure-Python solver
    # and report the time recorded by its own internal 'solve' timer.
    print('GPU not found: running Python solver')
    c = cbpdn.ConvBPDNGradReg(D, img, lmbda, mu, opt)
    X = c.solve().squeeze()
    t = c.timer.elapsed('solve')
else:
    print('%s GPU found: running CUDA solver' % cuda.device_name())
    tm = util.Timer()
    # sys_pipes() routes the CUDA code's output into the notebook while the
    # context timer measures the solve.
    with sys_pipes(), util.ContextTimer(tm):
        X = cuda.cbpdngrd(D, img, lmbda, mu, opt)
    t = tm.elapsed()
print('Solve time: %.2f s' % t)
GeForce RTX 2080 Ti GPU found: running CUDA solver
Itn   Fnc       DFid      Regℓ1     Regℓ2∇     r         s         ρ
--------------------------------------------------------------------------
   0  2.68e+07  1.30e+07  4.85e+04  2.76e+07  4.45e-01  9.48e+00  5.00e-01
   1  6.94e+07  3.32e+07  4.50e+04  7.26e+07  1.19e-01  2.28e+00  5.00e-01
   2  5.95e+07  2.79e+07  3.49e+04  6.33e+07  6.07e-02  7.70e-01  5.00e-01
   3  6.15e+07  2.87e+07  2.79e+04  6.57e+07  4.56e-02  5.12e-01  5.00e-01
   4  6.17e+07  2.87e+07  2.27e+04  6.59e+07  3.64e-02  4.01e-01  5.00e-01
   5  6.17e+07  2.87e+07  1.93e+04  6.60e+07  3.00e-02  3.33e-01  5.00e-01
   6  6.21e+07  2.90e+07  1.65e+04  6.62e+07  2.50e-02  2.85e-01  5.00e-01
   7  6.21e+07  2.89e+07  1.43e+04  6.64e+07  2.09e-02  2.53e-01  5.00e-01
   8  6.22e+07  2.90e+07  1.26e+04  6.66e+07  1.78e-02  2.24e-01  5.00e-01
   9  6.23e+07  2.89e+07  1.12e+04  6.67e+07  1.52e-02  2.01e-01  5.00e-01
  10  6.24e+07  2.90e+07  1.01e+04  6.68e+07  1.32e-02  1.80e-01  5.00e-01
  11  6.24e+07  2.89e+07  9.22e+03  6.69e+07  1.16e-02  1.60e-01  5.00e-01
  12  6.24e+07  2.90e+07  8.53e+03  6.69e+07  1.03e-02  1.43e-01  5.00e-01
  13  6.25e+07  2.90e+07  7.96e+03  6.70e+07  9.15e-03  1.29e-01  5.00e-01
  14  6.26e+07  2.90e+07  7.43e+03  6.71e+07  8.19e-03  1.18e-01  5.00e-01
  15  6.26e+07  2.90e+07  6.99e+03  6.71e+07  7.37e-03  1.09e-01  5.00e-01
  16  6.26e+07  2.91e+07  6.63e+03  6.71e+07  6.68e-03  1.00e-01  5.00e-01
  17  6.27e+07  2.91e+07  6.37e+03  6.71e+07  6.10e-03  9.21e-02  5.00e-01
  18  6.27e+07  2.91e+07  6.15e+03  6.72e+07  5.60e-03  8.43e-02  5.00e-01
  19  6.27e+07  2.91e+07  5.95e+03  6.72e+07  5.16e-03  7.72e-02  5.00e-01
  20  6.27e+07  2.91e+07  5.77e+03  6.72e+07  4.77e-03  7.11e-02  5.00e-01
  21  6.27e+07  2.91e+07  5.60e+03  6.72e+07  4.41e-03  6.61e-02  5.00e-01
  22  6.27e+07  2.91e+07  5.46e+03  6.72e+07  4.09e-03  6.18e-02  5.00e-01
  23  6.27e+07  2.91e+07  5.33e+03  6.72e+07  3.81e-03  5.77e-02  5.00e-01
  24  6.27e+07  2.91e+07  5.23e+03  6.72e+07  3.56e-03  5.36e-02  5.00e-01
  25  6.27e+07  2.91e+07  5.13e+03  6.72e+07  3.33e-03  4.98e-02  5.00e-01
  26  6.27e+07  2.91e+07  5.04e+03  6.72e+07  3.12e-03  4.64e-02  5.00e-01
  27  6.27e+07  2.91e+07  4.96e+03  6.72e+07  2.93e-03  4.34e-02  5.00e-01
  28  6.27e+07  2.91e+07  4.89e+03  6.72e+07  2.76e-03  4.06e-02  5.00e-01
  29  6.27e+07  2.91e+07  4.82e+03  6.72e+07  2.60e-03  3.82e-02  5.00e-01
  30  6.27e+07  2.91e+07  4.76e+03  6.72e+07  2.45e-03  3.60e-02  5.00e-01
  31  6.27e+07  2.91e+07  4.70e+03  6.72e+07  2.32e-03  3.42e-02  5.00e-01
  32  6.27e+07  2.91e+07  4.64e+03  6.72e+07  2.19e-03  3.27e-02  5.00e-01
  33  6.27e+07  2.91e+07  4.59e+03  6.72e+07  2.07e-03  3.12e-02  5.00e-01
  34  6.27e+07  2.91e+07  4.55e+03  6.72e+07  1.96e-03  2.95e-02  5.00e-01
  35  6.27e+07  2.91e+07  4.51e+03  6.72e+07  1.86e-03  2.78e-02  5.00e-01
  36  6.27e+07  2.91e+07  4.48e+03  6.72e+07  1.77e-03  2.61e-02  5.00e-01
  37  6.27e+07  2.91e+07  4.44e+03  6.72e+07  1.69e-03  2.47e-02  5.00e-01
  38  6.27e+07  2.91e+07  4.41e+03  6.72e+07  1.61e-03  2.34e-02  5.00e-01
  39  6.27e+07  2.91e+07  4.38e+03  6.72e+07  1.53e-03  2.23e-02  5.00e-01
  40  6.27e+07  2.91e+07  4.34e+03  6.72e+07  1.46e-03  2.13e-02  5.00e-01
  41  6.27e+07  2.91e+07  4.31e+03  6.72e+07  1.39e-03  2.04e-02  5.00e-01
  42  6.27e+07  2.91e+07  4.29e+03  6.72e+07  1.33e-03  1.95e-02  5.00e-01
  43  6.27e+07  2.91e+07  4.26e+03  6.72e+07  1.27e-03  1.87e-02  5.00e-01
  44  6.27e+07  2.91e+07  4.24e+03  6.72e+07  1.21e-03  1.78e-02  5.00e-01
  45  6.27e+07  2.91e+07  4.22e+03  6.72e+07  1.16e-03  1.70e-02  5.00e-01
  46  6.27e+07  2.91e+07  4.20e+03  6.72e+07  1.11e-03  1.62e-02  5.00e-01
  47  6.27e+07  2.91e+07  4.18e+03  6.72e+07  1.06e-03  1.55e-02  5.00e-01
  48  6.27e+07  2.91e+07  4.16e+03  6.72e+07  1.02e-03  1.49e-02  5.00e-01
  49  6.27e+07  2.91e+07  4.14e+03  6.72e+07  9.72e-04  1.43e-02  5.00e-01
  50  6.27e+07  2.91e+07  4.12e+03  6.72e+07  9.31e-04  1.37e-02  5.00e-01
  51  6.27e+07  2.91e+07  4.10e+03  6.72e+07  8.91e-04  1.32e-02  5.00e-01
  52  6.27e+07  2.91e+07  4.09e+03  6.72e+07  8.55e-04  1.27e-02  5.00e-01
  53  6.27e+07  2.91e+07  4.07e+03  6.72e+07  8.20e-04  1.21e-02  5.00e-01
  54  6.27e+07  2.91e+07  4.06e+03  6.72e+07  7.88e-04  1.16e-02  5.00e-01
  55  6.27e+07  2.91e+07  4.05e+03  6.72e+07  7.57e-04  1.11e-02  5.00e-01
  56  6.27e+07  2.91e+07  4.03e+03  6.72e+07  7.28e-04  1.06e-02  5.00e-01
  57  6.27e+07  2.91e+07  4.02e+03  6.72e+07  6.99e-04  1.02e-02  5.00e-01
  58  6.27e+07  2.91e+07  4.01e+03  6.72e+07  6.71e-04  9.83e-03  5.00e-01
  59  6.27e+07  2.91e+07  4.00e+03  6.72e+07  6.45e-04  9.46e-03  5.00e-01
  60  6.27e+07  2.91e+07  3.99e+03  6.72e+07  6.20e-04  9.11e-03  5.00e-01
  61  6.27e+07  2.91e+07  3.98e+03  6.72e+07  5.97e-04  8.76e-03  5.00e-01
  62  6.27e+07  2.91e+07  3.97e+03  6.72e+07  5.74e-04  8.42e-03  5.00e-01
  63  6.27e+07  2.91e+07  3.96e+03  6.72e+07  5.53e-04  8.11e-03  5.00e-01
  64  6.27e+07  2.91e+07  3.95e+03  6.72e+07  5.32e-04  7.83e-03  5.00e-01
  65  6.27e+07  2.91e+07  3.94e+03  6.72e+07  5.12e-04  7.56e-03  5.00e-01
  66  6.27e+07  2.91e+07  3.93e+03  6.72e+07  4.93e-04  7.29e-03  5.00e-01
  67  6.27e+07  2.91e+07  3.93e+03  6.72e+07  4.75e-04  7.03e-03  5.00e-01
  68  6.27e+07  2.91e+07  3.92e+03  6.72e+07  4.58e-04  6.77e-03  5.00e-01
  69  6.27e+07  2.91e+07  3.91e+03  6.72e+07  4.42e-04  6.50e-03  5.00e-01
  70  6.27e+07  2.91e+07  3.91e+03  6.72e+07  4.27e-04  6.26e-03  5.00e-01
  71  6.27e+07  2.91e+07  3.90e+03  6.72e+07  4.12e-04  6.04e-03  5.00e-01
  72  6.27e+07  2.91e+07  3.89e+03  6.72e+07  3.97e-04  5.84e-03  5.00e-01
  73  6.27e+07  2.91e+07  3.89e+03  6.72e+07  3.83e-04  5.65e-03  5.00e-01
  74  6.27e+07  2.91e+07  3.88e+03  6.72e+07  3.70e-04  5.47e-03  5.00e-01
  75  6.27e+07  2.91e+07  3.87e+03  6.72e+07  3.57e-04  5.29e-03  5.00e-01
  76  6.27e+07  2.91e+07  3.87e+03  6.72e+07  3.45e-04  5.11e-03  5.00e-01
  77  6.27e+07  2.91e+07  3.86e+03  6.72e+07  3.33e-04  4.95e-03  5.00e-01
--------------------------------------------------------------------------
Solve time: 0.98 s

Reconstruct the image from the sparse representation.

# Synthesise the image: convolve every filter with its coefficient map
# (over the two spatial axes) and sum the results over the filter axis.
imgr = fft.fftconv(D, X, axes=(0, 1)).sum(axis=2)
print("Reconstruction PSNR: %.2fdB\n" % metric.psnr(img, imgr))
Reconstruction PSNR: 45.50dB

Display representation and reconstructed image.

fig = plot.figure(figsize=(14, 14))
# Top row: the impulse-filter (lowpass) coefficient map, and the summed
# magnitude of all remaining coefficient maps.
plot.subplot(2, 2, 1)
plot.imview(X[..., 0].squeeze(), title='Lowpass component', fig=fig)
plot.subplot(2, 2, 2)
plot.imview(np.sum(abs(X[..., 1:]), axis=2).squeeze(),
            title='Main representation', cmap=plot.cm.Blues, fig=fig)
# Bottom row: the reconstruction and its deviation from the input image.
plot.subplot(2, 2, 3)
plot.imview(imgr, title='Reconstructed image', fig=fig)
plot.subplot(2, 2, 4)
plot.imview(imgr - img, title='Reconstruction difference', fltscl=True,
            fig=fig)
fig.show()
../../_images/cbpdn_grd_cuda_17_0.png