CUDA Convolutional Sparse Coding

This example demonstrates the use of the interface to the CUDA CSC solver extension package. It tests for the availability of a GPU, and runs the Python version of the ConvBPDN solver instead if no GPU is available or if the extension package is not installed.

from __future__ import print_function
from builtins import input

import pyfftw   # See https://github.com/pyFFTW/pyFFTW/issues/40
import numpy as np

from sporco import util
from sporco import signal
from sporco import fft
from sporco import metric
from sporco import plot
plot.config_notebook_plotting()
from sporco import cuda
from sporco.admm import cbpdn


# If running in a notebook, try to use wurlitzer so that output from the CUDA
# code will be properly captured in the notebook.
# sys_pipes is called later (as sys_pipes()) to obtain the context manager
# that wraps the CUDA solver call.
sys_pipes = util.notebook_system_output()

Load example image.

# Request the 'barbara.png' example image with scaled=True and gray=True.
# The index expression spans 10:522 on the first axis and 100:612 on the
# second, i.e. a 512 x 512 window (presumably applied as a crop of the full
# image -- confirm against the util.ExampleImages documentation).
img = util.ExampleImages().image('barbara.png', scaled=True, gray=True,
                                 idxexp=np.s_[10:522, 100:612])

Highpass filter example image.

# Split img into two components: sl is later displayed as the lowpass
# component and added back during reconstruction, while sh is the signal
# handed to the sparse coding solver.
# NOTE(review): npd looks like a boundary padding width and fltlmbd like the
# filter regularisation parameter -- confirm against signal.tikhonov_filter.
npd = 16
fltlmbd = 20
sl, sh = signal.tikhonov_filter(img, fltlmbd, npd)

Load dictionary.

# Pick the entry stored under the key 'G:12x12x36' from the collection
# returned by util.convdicts() (presumably 36 greyscale filters of size
# 12 x 12 -- TODO confirm against the util.convdicts documentation).
D = util.convdicts()['G:12x12x36']

Set up admm.cbpdn.ConvBPDN options.

# lmbda is passed to the solver alongside the dictionary and signal; opt
# collects the solver options, one per line for easier editing.
# NOTE(review): the recorded run stops at iteration 75, well short of
# 'MaxMainIter', once both residual columns fall below 'RelStopTol' --
# consult the ConvBPDN.Options documentation for the exact semantics.
lmbda = 1e-2
opt = cbpdn.ConvBPDN.Options({
    'Verbose': True,
    'MaxMainIter': 250,
    'HighMemSolve': True,
    'RelStopTol': 5e-3,
    'AuxVarObj': False,
})

If a GPU is available, run the CUDA ConvBPDN solver; otherwise run the standard Python version.

# Choose the solver according to the number of CUDA devices reported.
# Either branch leaves the solution in X and the elapsed solve time
# (reported below in seconds) in t.
if not (cuda.device_count() > 0):
    print('GPU not found: running Python solver')
    c = cbpdn.ConvBPDN(D, sh, lmbda, opt)
    X = c.solve()
    X = X.squeeze()
    # Read the time recorded by the solver object under the 'solve' label.
    t = c.timer.elapsed('solve')
else:
    print('%s GPU found: running CUDA solver' % cuda.device_name())
    tm = util.Timer()
    # sys_pipes() is entered first so that output from the CUDA code can be
    # captured; the timer context then brackets only the solver call.
    with sys_pipes():
        with util.ContextTimer(tm):
            X = cuda.cbpdn(D, sh, lmbda, opt)
    t = tm.elapsed()
print('Solve time: %.2f s' % t)
GeForce RTX 2080 Ti GPU found: running CUDA solver
Itn   Fnc       DFid      Regℓ1     r         s         ρ
----------------------------------------------------------------
   0  6.26e+04  6.25e+04  1.02e+04  6.47e-01  7.54e-01  1.50e+00
   1  6.36e+04  6.35e+04  8.59e+03  3.36e-01  5.78e-01  1.50e+00
   2  7.68e+04  7.67e+04  1.04e+04  3.20e-01  3.53e-01  1.11e+00
   3  7.73e+04  7.72e+04  8.21e+03  2.50e-01  2.95e-01  1.11e+00
   4  7.40e+04  7.40e+04  7.18e+03  2.21e-01  1.95e-01  1.00e+00
   5  7.55e+04  7.55e+04  6.99e+03  1.65e-01  1.48e-01  1.00e+00
   6  7.99e+04  7.98e+04  6.64e+03  1.28e-01  1.33e-01  1.00e+00
   7  8.40e+04  8.39e+04  6.30e+03  1.06e-01  1.11e-01  1.00e+00
   8  8.48e+04  8.48e+04  6.08e+03  8.85e-02  9.98e-02  1.00e+00
   9  8.45e+04  8.45e+04  5.76e+03  7.54e-02  8.71e-02  1.00e+00
  10  8.47e+04  8.47e+04  5.54e+03  6.90e-02  7.52e-02  9.06e-01
  11  8.64e+04  8.63e+04  5.46e+03  5.92e-02  6.68e-02  9.06e-01
  12  8.81e+04  8.81e+04  5.38e+03  5.15e-02  6.08e-02  9.06e-01
  13  8.95e+04  8.95e+04  5.32e+03  4.85e-02  5.57e-02  8.12e-01
  14  9.04e+04  9.03e+04  5.29e+03  4.61e-02  4.95e-02  7.38e-01
  15  9.12e+04  9.11e+04  5.21e+03  4.15e-02  4.47e-02  7.38e-01
  16  9.22e+04  9.21e+04  5.13e+03  3.75e-02  4.16e-02  7.38e-01
  17  9.33e+04  9.33e+04  5.06e+03  3.41e-02  3.87e-02  7.38e-01
  18  9.44e+04  9.44e+04  5.00e+03  3.12e-02  3.62e-02  7.38e-01
  19  9.54e+04  9.53e+04  4.96e+03  3.04e-02  3.41e-02  6.68e-01
  20  9.62e+04  9.62e+04  4.93e+03  2.83e-02  3.17e-02  6.68e-01
  21  9.70e+04  9.70e+04  4.90e+03  2.65e-02  2.95e-02  6.68e-01
  22  9.78e+04  9.78e+04  4.86e+03  2.48e-02  2.79e-02  6.68e-01
  23  9.86e+04  9.85e+04  4.83e+03  2.33e-02  2.65e-02  6.68e-01
  24  9.94e+04  9.93e+04  4.80e+03  2.30e-02  2.52e-02  6.09e-01
  25  1.00e+05  1.00e+05  4.79e+03  2.19e-02  2.36e-02  6.09e-01
  26  1.01e+05  1.01e+05  4.78e+03  2.08e-02  2.23e-02  6.09e-01
  27  1.01e+05  1.01e+05  4.76e+03  1.98e-02  2.12e-02  6.09e-01
  28  1.02e+05  1.02e+05  4.74e+03  1.88e-02  2.03e-02  6.09e-01
  29  1.02e+05  1.02e+05  4.71e+03  1.80e-02  1.95e-02  6.09e-01
  30  1.03e+05  1.03e+05  4.69e+03  1.72e-02  1.87e-02  6.09e-01
  31  1.03e+05  1.03e+05  4.67e+03  1.64e-02  1.80e-02  6.09e-01
  32  1.04e+05  1.03e+05  4.66e+03  1.58e-02  1.73e-02  6.09e-01
  33  1.04e+05  1.04e+05  4.64e+03  1.51e-02  1.67e-02  6.09e-01
  34  1.04e+05  1.04e+05  4.63e+03  1.45e-02  1.61e-02  6.09e-01
  35  1.04e+05  1.04e+05  4.61e+03  1.39e-02  1.55e-02  6.09e-01
  36  1.05e+05  1.05e+05  4.60e+03  1.34e-02  1.50e-02  6.09e-01
  37  1.05e+05  1.05e+05  4.59e+03  1.29e-02  1.45e-02  6.09e-01
  38  1.05e+05  1.05e+05  4.58e+03  1.25e-02  1.40e-02  6.09e-01
  39  1.06e+05  1.06e+05  4.57e+03  1.20e-02  1.35e-02  6.09e-01
  40  1.06e+05  1.06e+05  4.55e+03  1.16e-02  1.31e-02  6.09e-01
  41  1.06e+05  1.06e+05  4.54e+03  1.12e-02  1.27e-02  6.09e-01
  42  1.06e+05  1.06e+05  4.53e+03  1.08e-02  1.23e-02  6.09e-01
  43  1.06e+05  1.06e+05  4.52e+03  1.05e-02  1.19e-02  6.09e-01
  44  1.07e+05  1.07e+05  4.51e+03  1.01e-02  1.15e-02  6.09e-01
  45  1.07e+05  1.07e+05  4.51e+03  1.02e-02  1.11e-02  5.56e-01
  46  1.07e+05  1.07e+05  4.50e+03  9.94e-03  1.07e-02  5.56e-01
  47  1.07e+05  1.07e+05  4.50e+03  9.67e-03  1.03e-02  5.56e-01
  48  1.08e+05  1.08e+05  4.49e+03  9.39e-03  9.88e-03  5.56e-01
  49  1.08e+05  1.08e+05  4.49e+03  9.12e-03  9.56e-03  5.56e-01
  50  1.08e+05  1.08e+05  4.48e+03  8.85e-03  9.28e-03  5.56e-01
  51  1.08e+05  1.08e+05  4.47e+03  8.60e-03  9.01e-03  5.56e-01
  52  1.08e+05  1.08e+05  4.47e+03  8.35e-03  8.75e-03  5.56e-01
  53  1.08e+05  1.08e+05  4.46e+03  8.11e-03  8.51e-03  5.56e-01
  54  1.09e+05  1.09e+05  4.46e+03  7.88e-03  8.29e-03  5.56e-01
  55  1.09e+05  1.09e+05  4.45e+03  7.65e-03  8.07e-03  5.56e-01
  56  1.09e+05  1.09e+05  4.44e+03  7.44e-03  7.87e-03  5.56e-01
  57  1.09e+05  1.09e+05  4.44e+03  7.24e-03  7.67e-03  5.56e-01
  58  1.09e+05  1.09e+05  4.43e+03  7.04e-03  7.48e-03  5.56e-01
  59  1.09e+05  1.09e+05  4.42e+03  6.85e-03  7.28e-03  5.56e-01
  60  1.09e+05  1.09e+05  4.42e+03  6.67e-03  7.09e-03  5.56e-01
  61  1.09e+05  1.09e+05  4.42e+03  6.49e-03  6.91e-03  5.56e-01
  62  1.10e+05  1.10e+05  4.41e+03  6.32e-03  6.73e-03  5.56e-01
  63  1.10e+05  1.10e+05  4.41e+03  6.16e-03  6.57e-03  5.56e-01
  64  1.10e+05  1.10e+05  4.40e+03  6.00e-03  6.40e-03  5.56e-01
  65  1.10e+05  1.10e+05  4.40e+03  5.84e-03  6.24e-03  5.56e-01
  66  1.10e+05  1.10e+05  4.39e+03  5.69e-03  6.09e-03  5.56e-01
  67  1.10e+05  1.10e+05  4.39e+03  5.55e-03  5.94e-03  5.56e-01
  68  1.10e+05  1.10e+05  4.38e+03  5.41e-03  5.80e-03  5.56e-01
  69  1.10e+05  1.10e+05  4.38e+03  5.27e-03  5.66e-03  5.56e-01
  70  1.10e+05  1.10e+05  4.38e+03  5.14e-03  5.53e-03  5.56e-01
  71  1.10e+05  1.10e+05  4.37e+03  5.02e-03  5.40e-03  5.56e-01
  72  1.11e+05  1.11e+05  4.37e+03  4.89e-03  5.27e-03  5.56e-01
  73  1.11e+05  1.11e+05  4.37e+03  4.77e-03  5.14e-03  5.56e-01
  74  1.11e+05  1.11e+05  4.36e+03  4.66e-03  5.02e-03  5.56e-01
  75  1.11e+05  1.11e+05  4.36e+03  4.55e-03  4.90e-03  5.56e-01
----------------------------------------------------------------
Solve time: 1.04 s

Reconstruct the image from the sparse representation.

# Convolve D with X over axes 0 and 1, then collapse axis 2 by summation
# (presumably the filter index axis -- confirm against fft.fftconv).
shr = fft.fftconv(D, X, axes=(0, 1))
shr = np.sum(shr, axis=2)
# Add the lowpass component sl back to obtain the full reconstruction.
imgr = sl + shr
print("Reconstruction PSNR: %.2fdB\n" % (metric.psnr(img, imgr),))
Reconstruction PSNR: 44.12dB

Display representation and reconstructed image.

# Draw a 2 x 2 grid of panels on a single figure. Each plot.subplot call is
# immediately followed by the plot.imview call for that panel, so the
# statement order below must be preserved.
fig = plot.figure(figsize=(14, 14))
# Panel 1: the sl component produced by the Tikhonov filter.
plot.subplot(2, 2, 1)
plot.imview(sl, title='Lowpass component', fig=fig)
# Panel 2: absolute values of X summed over axis 2, shown with the Blues
# colour map.
plot.subplot(2, 2, 2)
plot.imview(np.sum(abs(X), axis=2).squeeze(),
            cmap=plot.cm.Blues, title='Main representation', fig=fig)
# Panel 3: the reconstruction imgr.
plot.subplot(2, 2, 3)
plot.imview(imgr, title='Reconstructed image', fig=fig)
# Panel 4: difference between the reconstruction and the original image.
# NOTE(review): fltscl=True presumably rescales the floating-point data for
# display -- confirm against the plot.imview documentation.
plot.subplot(2, 2, 4)
plot.imview(imgr - img, fltscl=True, title='Reconstruction difference',
            fig=fig)
fig.show()
../../_images/cbpdn_cuda_15_0.png