CUDA Convolutional Sparse Coding with Gradient Term

This example demonstrates the use of the interface to the CUDA CSC solver extension package. It tests for the availability of a GPU and falls back to the Python version of the ConvBPDNGradReg solver if a GPU is not available or if the extension package is not installed.

from __future__ import print_function
from builtins import input
from builtins import range

import pyfftw   # See https://github.com/pyFFTW/pyFFTW/issues/40
import numpy as np

from sporco import util
from sporco import plot
plot.config_notebook_plotting()
from sporco import cuda
from sporco.admm import cbpdn
import sporco.linalg as spl
import sporco.metric as spm

# If running in a notebook, try to use wurlitzer so that output from the CUDA
# code will be properly captured in the notebook. The returned object is
# called later (as ``sys_pipes()``) to obtain the context manager that wraps
# the CUDA solver call.
sys_pipes = util.notebook_system_output()

Load example image.

# Fetch the 'barbara.png' example image as a scaled, grayscale array and
# keep only the region selected by the index expression (rows 10:522,
# columns 100:612).
# NOTE(review): scaled=True presumably rescales pixel values to [0, 1] --
# confirm against the ExampleImages.image documentation.
img = util.ExampleImages().image(
    'barbara.png',
    scaled=True,
    gray=True,
    idxexp=np.s_[10:522, 100:612],
)

Load main dictionary and prepend an impulse filter for lowpass component representation.

# Main dictionary, selected by key from the set of example dictionaries.
Db = util.convdicts()['G:12x12x36']

# Impulse filter with the same spatial support as the dictionary filters:
# all zeros except for a single unit entry at the origin. The trailing axis
# has length 1 so that it can be stacked with Db along the filter axis.
di = np.zeros(Db.shape[:2] + (1,), dtype=np.float32)
di[0, 0, 0] = 1

# Full dictionary: impulse filter first (index 0), followed by the filters
# of the main dictionary.
D = np.concatenate([di, Db], axis=2)

Set up weights for the \(\ell_1\) norm to disable regularization of the coefficient map corresponding to the impulse filter.

# One l1 weight per filter, with two leading singleton axes ahead of the
# filter axis. Every weight is 1 except the one for filter 0 (the impulse
# filter), which is zeroed so that its coefficient map is not l1-penalised.
wl1 = np.ones((1, 1) + D.shape[2:], dtype=np.float32)
wl1[..., 0] = 0.0

Set up weights for the \(\ell_2\) norm of the gradient to disable regularization of all coefficient maps except for the one corresponding to the impulse filter.

# One gradient-term weight per filter: zero everywhere except for filter 0
# (the impulse filter), so only its coefficient map is gradient-regularised.
wgr = np.zeros(D.shape[2], dtype=np.float32)
wgr[0] = 1.0

Set up admm.cbpdn.ConvBPDNGradReg options.

# Regularisation parameters handed to the solver together with the options.
lmbda = 1e-2
mu = 5e-1

# Solver options: verbose per-iteration output, at most 250 iterations, a
# relative stopping tolerance of 5e-3, and a fixed penalty parameter
# rho = 0.5 (automatic rho adaptation is switched off). The per-filter
# weights constructed above are supplied via 'L1Weight' and 'GradWeight'.
opt = cbpdn.ConvBPDNGradReg.Options({
    'Verbose': True,
    'MaxMainIter': 250,
    'HighMemSolve': True,
    'RelStopTol': 5e-3,
    'AuxVarObj': False,
    'AutoRho': {'Enabled': False},
    'rho': 0.5,
    'L1Weight': wl1,
    'GradWeight': wgr,
})

If a GPU is available, run the CUDA ConvBPDNGradReg solver; otherwise run the standard Python version.

# Choose the solver according to the number of CUDA devices reported. In
# both branches X receives the solution and t the solve time in seconds.
if cuda.device_count() <= 0:
    # No GPU reported: construct the Python solver object and use its own
    # timer to obtain the duration of the solve.
    print('GPU not found: running Python solver')
    c = cbpdn.ConvBPDNGradReg(D, img, lmbda, mu, opt)
    X = c.solve().squeeze()
    t = c.timer.elapsed('solve')
else:
    print('%s GPU found: running CUDA solver' % cuda.device_name())
    # Time the CUDA call externally, with sys_pipes() wrapped around it so
    # that output from the CUDA code is captured when running in a notebook.
    tm = util.Timer()
    with sys_pipes(), util.ContextTimer(tm):
        X = cuda.cbpdngrd(D, img, lmbda, mu, opt)
    t = tm.elapsed()

print('Solve time: %.2f s' % t)
Tesla K40c GPU found: running CUDA solver
Itn   Fnc       DFid      Regℓ1     Regℓ2∇     r         s         ρ
--------------------------------------------------------------------------
   0  6.80e+03  6.30e+03  5.00e+04  4.21e-01  4.45e-01  9.61e+00  5.00e-01
   1  7.23e+02  2.56e+02  4.66e+04  2.10e+00  1.19e-01  2.31e+00  5.00e-01
   2  3.78e+02  1.55e+01  3.61e+04  3.32e+00  6.03e-02  7.76e-01  5.00e-01
   3  2.97e+02  5.75e+00  2.89e+04  4.52e+00  4.53e-02  5.17e-01  5.00e-01
   4  2.42e+02  3.56e+00  2.36e+04  5.92e+00  3.63e-02  4.06e-01  5.00e-01
   5  2.07e+02  3.05e+00  2.00e+04  7.31e+00  2.99e-02  3.37e-01  5.00e-01
   6  1.78e+02  2.91e+00  1.71e+04  8.73e+00  2.48e-02  2.88e-01  5.00e-01
   7  1.56e+02  2.84e+00  1.49e+04  1.02e+01  2.08e-02  2.56e-01  5.00e-01
   8  1.39e+02  2.76e+00  1.30e+04  1.16e+01  1.76e-02  2.27e-01  5.00e-01
   9  1.25e+02  2.71e+00  1.16e+04  1.30e+01  1.51e-02  2.04e-01  5.00e-01
  10  1.14e+02  2.68e+00  1.04e+04  1.43e+01  1.31e-02  1.83e-01  5.00e-01
  11  1.06e+02  2.63e+00  9.52e+03  1.56e+01  1.15e-02  1.63e-01  5.00e-01
  12  9.91e+01  2.60e+00  8.81e+03  1.68e+01  1.02e-02  1.45e-01  5.00e-01
  13  9.37e+01  2.58e+00  8.22e+03  1.79e+01  9.05e-03  1.31e-01  5.00e-01
  14  8.88e+01  2.57e+00  7.68e+03  1.89e+01  8.10e-03  1.19e-01  5.00e-01
  15  8.47e+01  2.58e+00  7.22e+03  1.98e+01  7.28e-03  1.10e-01  5.00e-01
  16  8.14e+01  2.59e+00  6.85e+03  2.06e+01  6.59e-03  1.02e-01  5.00e-01
  17  7.90e+01  2.59e+00  6.58e+03  2.13e+01  6.01e-03  9.33e-02  5.00e-01
  18  7.70e+01  2.58e+00  6.34e+03  2.19e+01  5.52e-03  8.55e-02  5.00e-01
  19  7.52e+01  2.57e+00  6.14e+03  2.24e+01  5.09e-03  7.82e-02  5.00e-01
  20  7.35e+01  2.58e+00  5.95e+03  2.29e+01  4.69e-03  7.19e-02  5.00e-01
  21  7.21e+01  2.59e+00  5.78e+03  2.32e+01  4.34e-03  6.66e-02  5.00e-01
  22  7.08e+01  2.62e+00  5.64e+03  2.36e+01  4.03e-03  6.22e-02  5.00e-01
  23  6.96e+01  2.60e+00  5.50e+03  2.39e+01  3.74e-03  5.82e-02  5.00e-01
  24  6.86e+01  2.62e+00  5.39e+03  2.41e+01  3.49e-03  5.41e-02  5.00e-01
  25  6.78e+01  2.63e+00  5.29e+03  2.44e+01  3.27e-03  5.02e-02  5.00e-01
  26  6.70e+01  2.64e+00  5.20e+03  2.46e+01  3.06e-03  4.66e-02  5.00e-01
  27  6.62e+01  2.65e+00  5.12e+03  2.47e+01  2.88e-03  4.35e-02  5.00e-01
  28  6.56e+01  2.66e+00  5.05e+03  2.49e+01  2.71e-03  4.08e-02  5.00e-01
  29  6.50e+01  2.65e+00  4.98e+03  2.50e+01  2.55e-03  3.83e-02  5.00e-01
  30  6.44e+01  2.67e+00  4.92e+03  2.51e+01  2.40e-03  3.61e-02  5.00e-01
  31  6.38e+01  2.66e+00  4.86e+03  2.52e+01  2.27e-03  3.42e-02  5.00e-01
  32  6.33e+01  2.69e+00  4.80e+03  2.53e+01  2.14e-03  3.27e-02  5.00e-01
  33  6.28e+01  2.69e+00  4.75e+03  2.53e+01  2.02e-03  3.12e-02  5.00e-01
  34  6.24e+01  2.69e+00  4.70e+03  2.54e+01  1.92e-03  2.95e-02  5.00e-01
  35  6.20e+01  2.68e+00  4.66e+03  2.54e+01  1.82e-03  2.78e-02  5.00e-01
  36  6.17e+01  2.71e+00  4.62e+03  2.55e+01  1.73e-03  2.62e-02  5.00e-01
  37  6.14e+01  2.70e+00  4.59e+03  2.55e+01  1.65e-03  2.47e-02  5.00e-01
  38  6.11e+01  2.70e+00  4.56e+03  2.56e+01  1.57e-03  2.35e-02  5.00e-01
  39  6.42e+01  6.13e+00  4.52e+03  2.56e+01  1.49e-03  2.24e-02  5.00e-01
  40  6.04e+01  2.70e+00  4.49e+03  2.56e+01  1.43e-03  2.13e-02  5.00e-01
  41  6.02e+01  2.72e+00  4.46e+03  2.57e+01  1.36e-03  2.04e-02  5.00e-01
  42  5.99e+01  2.71e+00  4.44e+03  2.57e+01  1.30e-03  1.95e-02  5.00e-01
  43  5.97e+01  2.74e+00  4.41e+03  2.57e+01  1.24e-03  1.87e-02  5.00e-01
  44  5.94e+01  2.71e+00  4.39e+03  2.57e+01  1.18e-03  1.78e-02  5.00e-01
  45  5.92e+01  2.71e+00  4.36e+03  2.57e+01  1.13e-03  1.71e-02  5.00e-01
  46  5.90e+01  2.72e+00  4.34e+03  2.58e+01  1.08e-03  1.62e-02  5.00e-01
  47  5.89e+01  2.73e+00  4.32e+03  2.58e+01  1.04e-03  1.55e-02  5.00e-01
  48  5.87e+01  2.73e+00  4.30e+03  2.58e+01  9.93e-04  1.49e-02  5.00e-01
  49  5.85e+01  2.73e+00  4.29e+03  2.58e+01  9.50e-04  1.43e-02  5.00e-01
  50  5.83e+01  2.73e+00  4.27e+03  2.58e+01  9.11e-04  1.37e-02  5.00e-01
  51  6.11e+01  5.69e+00  4.25e+03  2.58e+01  8.71e-04  1.32e-02  5.00e-01
  52  5.80e+01  2.74e+00  4.23e+03  2.58e+01  8.35e-04  1.27e-02  5.00e-01
  53  5.78e+01  2.74e+00  4.22e+03  2.58e+01  8.02e-04  1.21e-02  5.00e-01
  54  5.77e+01  2.74e+00  4.20e+03  2.58e+01  7.70e-04  1.16e-02  5.00e-01
  55  6.05e+01  5.67e+00  4.19e+03  2.59e+01  7.40e-04  1.11e-02  5.00e-01
  56  5.75e+01  2.74e+00  4.18e+03  2.59e+01  7.11e-04  1.07e-02  5.00e-01
  57  5.74e+01  2.74e+00  4.17e+03  2.59e+01  6.83e-04  1.03e-02  5.00e-01
  58  5.73e+01  2.77e+00  4.16e+03  2.59e+01  6.56e-04  9.87e-03  5.00e-01
  59  5.71e+01  2.75e+00  4.15e+03  2.59e+01  6.31e-04  9.51e-03  5.00e-01
  60  5.70e+01  2.75e+00  4.13e+03  2.59e+01  6.07e-04  9.15e-03  5.00e-01
  61  5.69e+01  2.75e+00  4.12e+03  2.59e+01  5.84e-04  8.80e-03  5.00e-01
  62  5.68e+01  2.74e+00  4.11e+03  2.59e+01  5.63e-04  8.46e-03  5.00e-01
  63  5.68e+01  2.75e+00  4.10e+03  2.59e+01  5.41e-04  8.16e-03  5.00e-01
  64  5.67e+01  2.75e+00  4.10e+03  2.59e+01  5.21e-04  7.86e-03  5.00e-01
  65  5.66e+01  2.74e+00  4.09e+03  2.59e+01  5.02e-04  7.59e-03  5.00e-01
  66  5.65e+01  2.75e+00  4.08e+03  2.59e+01  4.83e-04  7.33e-03  5.00e-01
  67  5.64e+01  2.75e+00  4.07e+03  2.59e+01  4.66e-04  7.06e-03  5.00e-01
  68  5.63e+01  2.75e+00  4.06e+03  2.59e+01  4.49e-04  6.80e-03  5.00e-01
  69  5.63e+01  2.77e+00  4.06e+03  2.59e+01  4.33e-04  6.55e-03  5.00e-01
  70  5.62e+01  2.76e+00  4.05e+03  2.59e+01  4.19e-04  6.30e-03  5.00e-01
  71  5.62e+01  2.75e+00  4.04e+03  2.59e+01  4.03e-04  6.08e-03  5.00e-01
  72  5.61e+01  2.75e+00  4.04e+03  2.59e+01  3.89e-04  5.88e-03  5.00e-01
  73  5.60e+01  2.76e+00  4.03e+03  2.59e+01  3.76e-04  5.69e-03  5.00e-01
  74  5.60e+01  2.76e+00  4.02e+03  2.59e+01  3.63e-04  5.51e-03  5.00e-01
  75  5.59e+01  2.76e+00  4.02e+03  2.59e+01  3.50e-04  5.33e-03  5.00e-01
  76  5.59e+01  2.76e+00  4.01e+03  2.59e+01  3.39e-04  5.15e-03  5.00e-01
  77  5.58e+01  2.76e+00  4.01e+03  2.59e+01  3.27e-04  4.98e-03  5.00e-01
--------------------------------------------------------------------------
Solve time: 1.61 s

Reconstruct the image from the sparse representation.

# Convolve each dictionary filter with its coefficient map and sum the
# results over the filter axis (axis 2) to obtain the reconstruction.
imgr = spl.fftconv(D, X).sum(axis=2)

# Report the reconstruction quality relative to the input image.
print("Reconstruction PSNR: %.2fdB\n" % spm.psnr(img, imgr))
Reconstruction PSNR: 45.66dB

Display representation and reconstructed image.

# 2x2 summary figure of the representation and the reconstruction.
fig = plot.figure(figsize=(14, 14))

# Top left: coefficient map of filter 0 (the impulse / lowpass filter).
plot.subplot(2, 2, 1)
plot.imview(X[..., 0].squeeze(), title='Lowpass component', fig=fig)

# Top right: absolute values of all remaining coefficient maps, summed over
# the filter axis.
plot.subplot(2, 2, 2)
plot.imview(
    np.abs(X[..., 1:]).sum(axis=2).squeeze(),
    cmap=plot.cm.Blues,
    title='Main representation',
    fig=fig,
)

# Bottom left: the reconstructed image.
plot.subplot(2, 2, 3)
plot.imview(imgr, title='Reconstructed image', fig=fig)

# Bottom right: difference between the reconstruction and the input image.
plot.subplot(2, 2, 4)
plot.imview(
    imgr - img,
    fltscl=True,
    title='Reconstruction difference',
    fig=fig,
)

fig.show()
../../_images/cbpdn_grd_cuda_17_0.png