CUDA Convolutional Sparse Coding

This example demonstrates the use of the interface to the CUDA CSC solver extension package. It tests for the availability of a GPU, and falls back to the Python version of the ConvBPDN solver if no GPU is available or if the extension package is not installed.

from __future__ import print_function
from builtins import input
from builtins import range

import pyfftw   # See https://github.com/pyFFTW/pyFFTW/issues/40
import numpy as np

from sporco import util
from sporco import plot
plot.config_notebook_plotting()
from sporco import cuda
from sporco.admm import cbpdn
import sporco.linalg as spl
import sporco.metric as spm

# If running in a notebook, try to use wurlitzer so that output from the CUDA
# code will be properly captured in the notebook. The object returned here is
# called further down (as sys_pipes()) to obtain the context manager that
# wraps the CUDA solver call.
sys_pipes = util.notebook_system_output()

Load example image.

# Fetch the 'barbara.png' example image, requesting a scaled grayscale
# version restricted by the index expression 10:522, 100:612 (a 512 x 512
# window, assuming idxexp is applied as a row/column index -- confirm against
# the ExampleImages.image documentation).
img = util.ExampleImages().image(
    'barbara.png',
    gray=True,
    scaled=True,
    idxexp=np.s_[10:522, 100:612])

Highpass filter example image.

# Parameters handed to the Tikhonov filter (presumably the regularization
# weight and the boundary padding width respectively -- confirm against the
# sporco.util.tikhonov_filter documentation).
fltlmbd = 20
npd = 16

# sl is the component later displayed as the lowpass image and added back
# during reconstruction; sh is the component passed to the sparse coding
# solver.
sl, sh = util.tikhonov_filter(img, fltlmbd, npd)

Load dictionary.

# Pick one of the dictionaries returned by util.convdicts() by its key
# (the key name suggests 36 grayscale filters of size 12 x 12 -- confirm).
dict_key = 'G:12x12x36'
D = util.convdicts()[dict_key]

Set up admm.cbpdn.ConvBPDN options.

# Third positional argument passed to both solvers below (presumably the
# weight of the l1 regularization term -- confirm against the ConvBPDN docs).
lmbda = 1e-2

# Solver settings, collected in a plain dict first and then wrapped in the
# ConvBPDN options class.
opt_values = {
    'Verbose': True,
    'MaxMainIter': 250,
    'HighMemSolve': True,
    'RelStopTol': 5e-3,
    'AuxVarObj': False,
}
opt = cbpdn.ConvBPDN.Options(opt_values)

If a GPU is available, run the CUDA ConvBPDN solver; otherwise, run the standard Python version.

# Decide once which solver to use: the CUDA path is taken only when at least
# one device is reported.
have_gpu = cuda.device_count() > 0

if have_gpu:
    print('%s GPU found: running CUDA solver' % cuda.device_name())
    # Time the CUDA call externally; sys_pipes() wraps the call so that output
    # from the CUDA code can be captured when running in a notebook.
    tm = util.Timer()
    with sys_pipes(), util.ContextTimer(tm):
        X = cuda.cbpdn(D, sh, lmbda, opt)
    t = tm.elapsed()
else:
    print('GPU not found: running Python solver')
    # The Python solver object keeps its own timer, so read the 'solve'
    # timing from it after the solve has finished.
    c = cbpdn.ConvBPDN(D, sh, lmbda, opt)
    X = c.solve().squeeze()
    t = c.timer.elapsed('solve')

print('Solve time: %.2f s' % t)
Tesla K40c GPU found: running CUDA solver
Itn   Fnc       DFid      Regℓ1     r         s         ρ
----------------------------------------------------------------
   0  1.07e+02  2.61e+00  1.05e+04  6.43e-01  7.75e-01  1.50e+00
   1  9.36e+01  4.78e+00  8.88e+03  3.34e-01  5.82e-01  1.50e+00
   2  1.11e+02  3.79e+00  1.07e+04  3.21e-01  3.54e-01  1.11e+00
   3  8.82e+01  3.66e+00  8.46e+03  2.50e-01  2.96e-01  1.11e+00
   4  7.80e+01  3.86e+00  7.41e+03  2.20e-01  1.94e-01  9.91e-01
   5  7.61e+01  4.02e+00  7.21e+03  1.64e-01  1.50e-01  9.91e-01
   6  7.26e+01  4.02e+00  6.86e+03  1.28e-01  1.33e-01  9.91e-01
   7  6.89e+01  3.94e+00  6.50e+03  1.05e-01  1.12e-01  9.91e-01
   8  6.65e+01  3.88e+00  6.26e+03  8.82e-02  1.00e-01  9.91e-01
   9  6.40e+01  3.89e+00  6.01e+03  8.00e-02  8.61e-02  9.04e-01
  10  6.14e+01  3.93e+00  5.74e+03  6.83e-02  7.38e-02  9.04e-01
  11  6.03e+01  3.96e+00  5.64e+03  5.83e-02  6.77e-02  9.04e-01
  12  5.99e+01  3.95e+00  5.60e+03  5.41e-02  6.09e-02  8.17e-01
  13  5.90e+01  3.91e+00  5.51e+03  4.79e-02  5.46e-02  8.17e-01
  14  5.83e+01  3.89e+00  5.44e+03  4.50e-02  5.01e-02  7.45e-01
  15  5.74e+01  3.89e+00  5.35e+03  4.05e-02  4.54e-02  7.45e-01
  16  5.66e+01  3.89e+00  5.27e+03  3.66e-02  4.20e-02  7.45e-01
  17  5.61e+01  3.89e+00  5.22e+03  3.52e-02  3.90e-02  6.77e-01
  18  5.58e+01  3.89e+00  5.19e+03  3.23e-02  3.58e-02  6.77e-01
  19  5.53e+01  3.88e+00  5.14e+03  2.98e-02  3.37e-02  6.77e-01
  20  5.48e+01  3.88e+00  5.09e+03  2.77e-02  3.18e-02  6.77e-01
  21  5.44e+01  3.88e+00  5.05e+03  2.71e-02  2.97e-02  6.16e-01
  22  5.41e+01  3.88e+00  5.02e+03  2.55e-02  2.77e-02  6.16e-01
  23  5.39e+01  3.88e+00  5.00e+03  2.40e-02  2.61e-02  6.16e-01
  24  5.37e+01  3.88e+00  4.98e+03  2.26e-02  2.49e-02  6.16e-01
  25  5.34e+01  3.88e+00  4.95e+03  2.14e-02  2.37e-02  6.16e-01
  26  5.32e+01  3.88e+00  4.93e+03  2.03e-02  2.26e-02  6.16e-01
  27  5.29e+01  3.88e+00  4.90e+03  1.93e-02  2.16e-02  6.16e-01
  28  5.26e+01  3.88e+00  4.88e+03  1.84e-02  2.07e-02  6.16e-01
  29  5.24e+01  3.88e+00  4.85e+03  1.76e-02  1.99e-02  6.16e-01
  30  5.22e+01  3.88e+00  4.84e+03  1.68e-02  1.91e-02  6.16e-01
  31  5.21e+01  3.88e+00  4.82e+03  1.61e-02  1.83e-02  6.16e-01
  32  5.19e+01  3.88e+00  4.81e+03  1.54e-02  1.77e-02  6.16e-01
  33  5.18e+01  3.88e+00  4.79e+03  1.55e-02  1.70e-02  5.60e-01
  34  5.17e+01  3.88e+00  4.78e+03  1.50e-02  1.62e-02  5.60e-01
  35  5.16e+01  3.88e+00  4.78e+03  1.44e-02  1.54e-02  5.60e-01
  36  5.15e+01  3.88e+00  4.77e+03  1.39e-02  1.48e-02  5.60e-01
  37  5.14e+01  3.88e+00  4.75e+03  1.34e-02  1.43e-02  5.60e-01
  38  5.13e+01  3.88e+00  4.74e+03  1.29e-02  1.38e-02  5.60e-01
  39  5.12e+01  3.88e+00  4.73e+03  1.25e-02  1.33e-02  5.60e-01
  40  5.11e+01  3.88e+00  4.72e+03  1.20e-02  1.29e-02  5.60e-01
  41  5.09e+01  3.88e+00  4.71e+03  1.16e-02  1.25e-02  5.60e-01
  42  5.08e+01  3.89e+00  4.69e+03  1.12e-02  1.21e-02  5.60e-01
  43  5.07e+01  3.89e+00  4.68e+03  1.09e-02  1.17e-02  5.60e-01
  44  5.06e+01  3.89e+00  4.68e+03  1.05e-02  1.14e-02  5.60e-01
  45  5.06e+01  3.89e+00  4.67e+03  1.02e-02  1.10e-02  5.60e-01
  46  5.05e+01  3.89e+00  4.66e+03  9.88e-03  1.07e-02  5.60e-01
  47  5.04e+01  3.89e+00  4.65e+03  9.58e-03  1.03e-02  5.60e-01
  48  5.03e+01  3.89e+00  4.64e+03  9.27e-03  1.00e-02  5.60e-01
  49  5.02e+01  3.89e+00  4.63e+03  8.99e-03  9.75e-03  5.60e-01
  50  5.02e+01  3.89e+00  4.63e+03  8.73e-03  9.47e-03  5.60e-01
  51  5.01e+01  3.89e+00  4.62e+03  8.47e-03  9.20e-03  5.60e-01
  52  5.00e+01  3.89e+00  4.61e+03  8.22e-03  8.95e-03  5.60e-01
  53  5.00e+01  3.89e+00  4.61e+03  8.00e-03  8.71e-03  5.60e-01
  54  4.99e+01  3.89e+00  4.60e+03  7.77e-03  8.50e-03  5.60e-01
  55  4.98e+01  3.89e+00  4.59e+03  7.54e-03  8.26e-03  5.60e-01
  56  4.98e+01  3.89e+00  4.59e+03  7.33e-03  8.05e-03  5.60e-01
  57  4.97e+01  3.89e+00  4.58e+03  7.13e-03  7.86e-03  5.60e-01
  58  4.96e+01  3.89e+00  4.57e+03  6.93e-03  7.67e-03  5.60e-01
  59  4.96e+01  3.89e+00  4.57e+03  6.74e-03  7.47e-03  5.60e-01
  60  4.95e+01  3.89e+00  4.56e+03  6.57e-03  7.27e-03  5.60e-01
  61  4.95e+01  3.89e+00  4.56e+03  6.40e-03  7.09e-03  5.60e-01
  62  4.94e+01  3.89e+00  4.55e+03  6.23e-03  6.90e-03  5.60e-01
  63  4.94e+01  3.89e+00  4.55e+03  6.08e-03  6.74e-03  5.60e-01
  64  4.93e+01  3.89e+00  4.55e+03  5.90e-03  6.56e-03  5.60e-01
  65  4.93e+01  3.89e+00  4.54e+03  5.75e-03  6.39e-03  5.60e-01
  66  4.93e+01  3.89e+00  4.54e+03  5.61e-03  6.24e-03  5.60e-01
  67  4.92e+01  3.89e+00  4.53e+03  5.47e-03  6.09e-03  5.60e-01
  68  4.92e+01  3.89e+00  4.53e+03  5.34e-03  5.94e-03  5.60e-01
  69  4.91e+01  3.89e+00  4.52e+03  5.19e-03  5.80e-03  5.60e-01
  70  4.91e+01  3.89e+00  4.52e+03  5.07e-03  5.67e-03  5.60e-01
  71  4.91e+01  3.89e+00  4.52e+03  4.94e-03  5.52e-03  5.60e-01
  72  4.90e+01  3.89e+00  4.51e+03  4.81e-03  5.39e-03  5.60e-01
  73  4.90e+01  3.89e+00  4.51e+03  4.70e-03  5.25e-03  5.60e-01
  74  4.90e+01  3.90e+00  4.51e+03  4.59e-03  5.13e-03  5.60e-01
  75  4.89e+01  3.90e+00  4.50e+03  4.48e-03  5.02e-03  5.60e-01
  76  4.89e+01  3.90e+00  4.50e+03  4.37e-03  4.90e-03  5.60e-01
----------------------------------------------------------------
Solve time: 1.60 s

Reconstruct the image from the sparse representation.

# Convolve the dictionary with the coefficient maps and sum the result over
# axis 2 to obtain the reconstruction of the component given to the solver.
conv_terms = spl.fftconv(D, X)
shr = conv_terms.sum(axis=2)

# Add back the lowpass component to obtain the full reconstructed image.
imgr = sl + shr

recon_psnr = spm.psnr(img, imgr)
print("Reconstruction PSNR: %.2fdB\n" % recon_psnr)
Reconstruction PSNR: 44.30dB

Display representation and reconstructed image.

# Sum of coefficient magnitudes over axis 2 of X, shown as a single map.
coef_map = np.abs(X).sum(axis=2).squeeze()

# 2 x 2 grid: lowpass component, coefficient map, reconstruction, and the
# difference between the reconstruction and the original image.
fig = plot.figure(figsize=(14, 14))

plot.subplot(2, 2, 1)
plot.imview(sl, title='Lowpass component', fig=fig)

plot.subplot(2, 2, 2)
plot.imview(coef_map, cmap=plot.cm.Blues, title='Main representation',
            fig=fig)

plot.subplot(2, 2, 3)
plot.imview(imgr, title='Reconstructed image', fig=fig)

plot.subplot(2, 2, 4)
plot.imview(imgr - img, fltscl=True, title='Reconstruction difference',
            fig=fig)

fig.show()
../../_images/cbpdn_cuda_15_0.png