CUDA Convolutional Sparse Coding with Gradient Term

This example demonstrates the use of the interface to the CUDA CSC solver extension package. It tests for the availability of a GPU and falls back to the Python version of the ConvBPDNGradReg solver if no GPU is available, or if the extension package is not installed.

from __future__ import print_function
from builtins import input

import pyfftw   # See https://github.com/pyFFTW/pyFFTW/issues/40
import numpy as np

from sporco import util
from sporco import fft
from sporco import metric
from sporco import plot
plot.config_notebook_plotting()
from sporco import cuda
from sporco.admm import cbpdn

# If running in a notebook, try to use wurlitzer so that output from the CUDA
# code will be properly captured in the notebook. The returned object is
# called to obtain a context manager that wraps the CUDA solver call further
# down in this example.
sys_pipes = util.notebook_system_output()

Load example image.

# Fetch the 'barbara.png' example image, requesting the scaled greyscale
# form, and crop it with an index expression selecting rows 10..521 and
# columns 100..611.
img = util.ExampleImages().image(
    'barbara.png',
    idxexp=np.s_[10:522, 100:612],
    scaled=True,
    gray=True,
)

Load main dictionary and prepend an impulse filter for lowpass component representation.

# Main dictionary, looked up by name in the set of example dictionaries.
Db = util.convdicts()['G:12x12x36']
# Single-filter array with the same spatial support as Db, zero everywhere
# except for a unit sample at the origin (an impulse).
di = np.zeros(Db.shape[:2] + (1,), dtype=np.float32)
di[0, 0, 0] = 1
# Stack along the filter axis so that the impulse becomes filter index 0.
D = np.concatenate([di, Db], axis=2)

Set up weights for the \(\ell_1\) norm to disable regularization of the coefficient map corresponding to the impulse filter.

# Unit weight for every filter, with two leading singleton axes; the entry
# for filter 0 (the impulse filter) is then zeroed.
wl1 = np.ones((1, 1) + tuple(D.shape[2:]), dtype=np.float32)
wl1[..., 0] = 0

Set up weights for the \(\ell_2\) norm of the gradient to disable regularization of all coefficient maps except for the one corresponding to the impulse filter.

# One weight per filter: zero everywhere except for filter 0 (the impulse
# filter), which gets unit weight.
wgr = np.zeros(D.shape[2], dtype=np.float32)
wgr[0] = 1

Set up admm.cbpdn.ConvBPDNGradReg options.

# Scalar parameters handed to the solver together with D and img
# (presumably the l1 and gradient regularization weights respectively --
# confirm against the ConvBPDNGradReg documentation).
lmbda = 0.01
mu = 0.5

# Solver options; the per-filter weight arrays built above are attached via
# the 'L1Weight' and 'GradWeight' entries.
opt = cbpdn.ConvBPDNGradReg.Options({
    'Verbose': True,
    'MaxMainIter': 250,
    'HighMemSolve': True,
    'RelStopTol': 0.005,
    'AuxVarObj': False,
    'AutoRho': {'Enabled': False},
    'rho': 0.5,
    'L1Weight': wl1,
    'GradWeight': wgr,
})

If a GPU is available, run the CUDA ConvBPDNGradReg solver; otherwise run the standard Python version.

if cuda.device_count() > 0:
    # GPU path: time the CUDA solver call with an external timer, inside the
    # sys_pipes context manager set up earlier.
    print('%s GPU found: running CUDA solver' % cuda.device_name())
    tm = util.Timer()
    with sys_pipes():
        with util.ContextTimer(tm):
            X = cuda.cbpdngrd(D, img, lmbda, mu, opt)
    t = tm.elapsed()
else:
    # Fallback path: run the Python solver object and read the solve time
    # from its own timer.
    print('GPU not found: running Python solver')
    c = cbpdn.ConvBPDNGradReg(D, img, lmbda, mu, opt)
    X = c.solve()
    # Drop singleton axes from the solution array.
    X = X.squeeze()
    t = c.timer.elapsed('solve')
print('Solve time: %.2f s' % t)
GPU not found: running Python solver
Itn   Fnc       DFid      Regℓ1     Regℓ2∇    r         s
----------------------------------------------------------------
   0  3.50e+03  3.02e+03  4.85e+04  3.62e-01  4.45e-01  9.48e+00
   1  5.78e+02  1.27e+02  4.50e+04  1.88e+00  1.19e-01  2.28e+00
   2  3.60e+02  9.87e+00  3.49e+04  3.06e+00  6.07e-02  7.70e-01
   3  2.86e+02  4.87e+00  2.79e+04  4.19e+00  4.56e-02  5.12e-01
   4  2.34e+02  3.36e+00  2.27e+04  5.51e+00  3.64e-02  4.01e-01
   5  2.00e+02  2.92e+00  1.93e+04  6.84e+00  3.00e-02  3.33e-01
   6  1.72e+02  2.79e+00  1.65e+04  8.18e+00  2.50e-02  2.85e-01
   7  1.51e+02  2.73e+00  1.43e+04  9.54e+00  2.09e-02  2.53e-01
   8  1.34e+02  2.67e+00  1.26e+04  1.09e+01  1.78e-02  2.24e-01
   9  1.20e+02  2.62e+00  1.12e+04  1.22e+01  1.52e-02  2.01e-01
  10  1.10e+02  2.57e+00  1.01e+04  1.35e+01  1.32e-02  1.80e-01
  11  1.02e+02  2.53e+00  9.22e+03  1.47e+01  1.16e-02  1.60e-01
  12  9.58e+01  2.51e+00  8.53e+03  1.59e+01  1.03e-02  1.43e-01
  13  9.05e+01  2.49e+00  7.96e+03  1.69e+01  9.15e-03  1.29e-01
  14  8.57e+01  2.49e+00  7.43e+03  1.78e+01  8.19e-03  1.18e-01
  15  8.17e+01  2.50e+00  6.99e+03  1.87e+01  7.37e-03  1.09e-01
  16  7.85e+01  2.50e+00  6.63e+03  1.94e+01  6.68e-03  1.00e-01
  17  7.62e+01  2.50e+00  6.37e+03  2.00e+01  6.10e-03  9.21e-02
  18  7.43e+01  2.50e+00  6.15e+03  2.06e+01  5.60e-03  8.43e-02
  19  7.25e+01  2.50e+00  5.95e+03  2.10e+01  5.16e-03  7.72e-02
  20  7.09e+01  2.51e+00  5.77e+03  2.14e+01  4.77e-03  7.11e-02
  21  6.94e+01  2.52e+00  5.60e+03  2.18e+01  4.41e-03  6.61e-02
  22  6.82e+01  2.53e+00  5.46e+03  2.21e+01  4.09e-03  6.18e-02
  23  6.70e+01  2.54e+00  5.33e+03  2.24e+01  3.81e-03  5.77e-02
  24  6.61e+01  2.55e+00  5.23e+03  2.26e+01  3.56e-03  5.36e-02
  25  6.52e+01  2.56e+00  5.13e+03  2.28e+01  3.33e-03  4.98e-02
  26  6.44e+01  2.57e+00  5.04e+03  2.30e+01  3.12e-03  4.64e-02
  27  6.37e+01  2.57e+00  4.96e+03  2.31e+01  2.93e-03  4.34e-02
  28  6.31e+01  2.58e+00  4.89e+03  2.32e+01  2.76e-03  4.06e-02
  29  6.25e+01  2.59e+00  4.82e+03  2.33e+01  2.60e-03  3.82e-02
  30  6.19e+01  2.60e+00  4.76e+03  2.34e+01  2.45e-03  3.60e-02
  31  6.14e+01  2.60e+00  4.70e+03  2.35e+01  2.32e-03  3.42e-02
  32  6.08e+01  2.61e+00  4.64e+03  2.36e+01  2.19e-03  3.27e-02
  33  6.04e+01  2.61e+00  4.59e+03  2.36e+01  2.07e-03  3.12e-02
  34  5.99e+01  2.62e+00  4.55e+03  2.37e+01  1.96e-03  2.95e-02
  35  5.96e+01  2.62e+00  4.51e+03  2.37e+01  1.86e-03  2.78e-02
  36  5.93e+01  2.63e+00  4.48e+03  2.38e+01  1.77e-03  2.61e-02
  37  5.89e+01  2.63e+00  4.44e+03  2.38e+01  1.69e-03  2.47e-02
  38  5.86e+01  2.63e+00  4.41e+03  2.38e+01  1.61e-03  2.34e-02
  39  5.83e+01  2.64e+00  4.38e+03  2.39e+01  1.53e-03  2.23e-02
  40  5.80e+01  2.64e+00  4.34e+03  2.39e+01  1.46e-03  2.13e-02
  41  5.77e+01  2.64e+00  4.31e+03  2.39e+01  1.39e-03  2.04e-02
  42  5.75e+01  2.65e+00  4.29e+03  2.39e+01  1.33e-03  1.95e-02
  43  5.72e+01  2.65e+00  4.26e+03  2.40e+01  1.27e-03  1.87e-02
  44  5.70e+01  2.65e+00  4.24e+03  2.40e+01  1.21e-03  1.78e-02
  45  5.68e+01  2.65e+00  4.22e+03  2.40e+01  1.16e-03  1.70e-02
  46  5.66e+01  2.66e+00  4.20e+03  2.40e+01  1.11e-03  1.62e-02
  47  5.64e+01  2.66e+00  4.18e+03  2.40e+01  1.06e-03  1.55e-02
  48  5.62e+01  2.66e+00  4.16e+03  2.40e+01  1.02e-03  1.49e-02
  49  5.61e+01  2.66e+00  4.14e+03  2.40e+01  9.72e-04  1.43e-02
  50  5.59e+01  2.66e+00  4.12e+03  2.40e+01  9.31e-04  1.37e-02
  51  5.57e+01  2.67e+00  4.10e+03  2.41e+01  8.91e-04  1.32e-02
  52  5.56e+01  2.67e+00  4.09e+03  2.41e+01  8.55e-04  1.27e-02
  53  5.54e+01  2.67e+00  4.07e+03  2.41e+01  8.20e-04  1.21e-02
  54  5.53e+01  2.67e+00  4.06e+03  2.41e+01  7.88e-04  1.16e-02
  55  5.52e+01  2.67e+00  4.05e+03  2.41e+01  7.57e-04  1.11e-02
  56  5.51e+01  2.67e+00  4.03e+03  2.41e+01  7.28e-04  1.06e-02
  57  5.50e+01  2.68e+00  4.02e+03  2.41e+01  6.99e-04  1.02e-02
  58  5.48e+01  2.68e+00  4.01e+03  2.41e+01  6.71e-04  9.83e-03
  59  5.47e+01  2.68e+00  4.00e+03  2.41e+01  6.45e-04  9.46e-03
  60  5.46e+01  2.68e+00  3.99e+03  2.41e+01  6.20e-04  9.11e-03
  61  5.45e+01  2.68e+00  3.98e+03  2.41e+01  5.97e-04  8.76e-03
  62  5.44e+01  2.68e+00  3.97e+03  2.41e+01  5.74e-04  8.42e-03
  63  5.43e+01  2.68e+00  3.96e+03  2.41e+01  5.53e-04  8.11e-03
  64  5.43e+01  2.68e+00  3.95e+03  2.41e+01  5.32e-04  7.83e-03
  65  5.42e+01  2.68e+00  3.94e+03  2.41e+01  5.12e-04  7.56e-03
  66  5.41e+01  2.68e+00  3.93e+03  2.41e+01  4.93e-04  7.29e-03
  67  5.40e+01  2.68e+00  3.93e+03  2.41e+01  4.75e-04  7.03e-03
  68  5.39e+01  2.69e+00  3.92e+03  2.41e+01  4.58e-04  6.77e-03
  69  5.39e+01  2.69e+00  3.91e+03  2.41e+01  4.42e-04  6.50e-03
  70  5.38e+01  2.69e+00  3.91e+03  2.41e+01  4.27e-04  6.26e-03
  71  5.38e+01  2.69e+00  3.90e+03  2.41e+01  4.12e-04  6.04e-03
  72  5.37e+01  2.69e+00  3.89e+03  2.41e+01  3.97e-04  5.84e-03
  73  5.36e+01  2.69e+00  3.89e+03  2.42e+01  3.83e-04  5.65e-03
  74  5.36e+01  2.69e+00  3.88e+03  2.42e+01  3.70e-04  5.47e-03
  75  5.35e+01  2.69e+00  3.87e+03  2.42e+01  3.57e-04  5.29e-03
  76  5.35e+01  2.69e+00  3.87e+03  2.42e+01  3.45e-04  5.11e-03
  77  5.34e+01  2.69e+00  3.86e+03  2.42e+01  3.33e-04  4.95e-03
----------------------------------------------------------------
Solve time: 48.24 s

Reconstruct the image from the sparse representation.

# Convolve each filter with its coefficient map over the two spatial axes,
# then sum the per-filter results along the filter axis.
imgr = fft.fftconv(D, X, axes=(0, 1)).sum(axis=2)
print("Reconstruction PSNR: %.2fdB\n" % metric.psnr(img, imgr))
Reconstruction PSNR: 45.50dB

Display representation and reconstructed image.

fig = plot.figure(figsize=(14, 14))

# Top left: coefficient map of filter 0 (the impulse filter).
plot.subplot(2, 2, 1)
plot.imview(np.squeeze(X[..., 0]), title='Lowpass component', fig=fig)

# Top right: sum of absolute coefficient values over all remaining filters.
plot.subplot(2, 2, 2)
plot.imview(np.squeeze(np.abs(X[..., 1:]).sum(axis=2)),
            cmap=plot.cm.Blues, title='Main representation', fig=fig)

# Bottom left: image reconstructed from the sparse representation.
plot.subplot(2, 2, 3)
plot.imview(imgr, title='Reconstructed image', fig=fig)

# Bottom right: difference between the reconstruction and the input image.
plot.subplot(2, 2, 4)
plot.imview(imgr - img, fltscl=True, title='Reconstruction difference',
            fig=fig)

fig.show()
../../_images/cbpdn_grd_cuda_17_0.png