Problem Sheet 2

In [6]:
import numpy as np
from scipy.sparse.linalg import svds

# pywt for wavelet transform implementation
import pywt
import pywt.data

# Wrapper function around pywt transform (to get vector as an output)
from func.transforms2d import transform_wavelet2

# matplotlib for plotiing
import matplotlib.pyplot as plt
%matplotlib inline

Sparse representation

Techninques for finding and using sparse representations are widely used in image processing (e.g. lossy image and lossy video compression)

A simple demonstration of how the cameraman image $y\in R^{512^2}$ can be compressed using Daubechies 4 wavelet transform:

In [24]:
x = pywt.data.camera()
x = x / np.linalg.norm(image, 'fro')
plt.imshow(x, cmap ='gray')
plt.show()
In [25]:
Psi = transform_wavelet2('Db4', mode = 'per', level = 5)
x = Psi.forward(image)
plt.semilogy(np.sort(np.abs(x))[::-1], linewidth=2)
plt.show()

Construct $\Psi$ to be a wavelet transform (Daubechies 4 order). Apply $x = \Psi y$. Now $x\in R^{512^2}$ is a vector of wavelet coefficients.

Notice in the figure above fast decay of coefficient values. This suggests that most of the information content of the image is contained only in few coefficients.

Keep only $5\%$ of the largest coefficients (in absolute value). Transform back into the image domain and observe the image. $$ x_k = HT(x, 5\%)\\ y_k = \Psi^*(x_k) $$

In [19]:
def thresh_hard_sparse(x, k):
    """
    Keep only k largest entries of x and return their indices.
    Parameters
    ----------
    x : numpy array
        Numpy array to be thresholded
    k : int
        Number of largest entries in absolute value to keep
    Notes
    """
    _x = x.copy()
    ind = np.argpartition(abs(_x), -k, axis=None)[-k:]
    ind = np.unravel_index(ind, _x.shape)
    ind_del = np.ones(_x.shape, dtype=bool)
    ind_del[ind] = False
    _x[ind_del] = 0
    return ind, _x
In [26]:
rho = 0.05
_, x_k = thresh_hard_sparse(x, round(rho*512**2))
image_k = Psi.backward(x_k)

fig, (ax1, ax2) = plt.subplots(1, 2,figsize=(14, 3))
ax1.semilogy(np.sort(np.abs(x))[::-1], linewidth=2)
ax2.semilogy(np.sort(np.abs(x_k))[::-1], linewidth=2)
plt.show()

fig, (ax1, ax2) = plt.subplots(1, 2,figsize=(16, 12))
ax1.imshow(image, cmap ='gray')
ax1.set_title('Cameraman')

ax2.imshow(image_k, cmap ='gray')
ax2.set_title('Cameraman thresholded')
plt.show()

Normalized Iterative Hard Thresholding (NIHT)

First implement NIHT (Blumensath & Davies 2010) for solving underdetermined linear system with sparse constraint:

$$ \min_{x} \| Ax - y \|_2^2, \qquad \text{s.t.}\quad \|x\|_0\leq k, $$

where $A\in R^{m\times n}$, $x\in R^n$, and $y\in R^{m}$, (where $n\geq m \geq k$).

Note: Some of you who took C6.1 Numerical Linear Algebra last term and know of Conjugate Gradient (CG) method for solving linear systems (Hestenes & Stiefel 1952), there is an equivalent method for solving linear underdetermined systems with sparsity constraints CGIHT (Blanchard et al.2015) which has faster convergence rates.

In [27]:
def support_projection(x, ind):
    """
    Keeps only coefficients at specified indices, setting others to zero.
    ----------
    x : numpy array
        Numpy array to be projected
    ind : int
        where to keep entries keep
    """
    _x = x.copy()
    ind_del = np.ones(x.shape, dtype=bool)
    ind_del[ind] = False
    _x[ind_del] = 0
    return x
In [28]:
def niht(A, y, k, tol = 1e-4, MAX_ITER = 100):
    """
    Normalized Iterative Hard Thresholding solving underdetermined system
    Ax = b,
    Parameters
    ----------
    A : numpy array
        Underdetermined matrix
    y : numpy array
        Right-hand side vector
    k : int
        Sparsity constraint on x
    Notes
    """
    error = np.zeros(MAX_ITER)
    w = A.T.dot(y)
    T_k, x = thresh_hard_sparse(w, k)
    error[0] = np.linalg.norm(A.dot(x) - y)/np.linalg.norm(y)  
    # Iterative process
    l = 2
    not_finished = True
    while not_finished:
        r = A.T.dot(y-A.dot(x))
        r_proj = support_projection(r, T_k)
        a = np.linalg.norm(r_proj)**2
        b = np.linalg.norm(A.dot(r_proj))**2
        alpha = a/b
        w = x + alpha * r
        T_k, x = thresh_hard_sparse(w, k)
        error[l-1] = np.linalg.norm(A.dot(x) - y)/np.linalg.norm(y) 
        not_finished = (l < MAX_ITER) and (error[l-1] >= tol)
        l = l + 1
    return (x, error[:(l-1)])

Test example of NIHT (Gaussian random matrix)

Here we test whether NIHT solver works for a generated toy problem. Consider a matrix $A\in R^{200\times 500}$ with i.i.d. entries sampled from a Gaussian distribution. Generate a random $k$-sparse $x\in R^{500}$, $\|x\|_0 = k$ and observe $y = Ax$. We run NIHT to test whether we can recover $x$ only from knowing $y$ and $A$.

In [29]:
m = 200
n= 500
k = 10
x = np.random.randn(n, 1)
A = np.random.randn(m, n) / np.sqrt(m)
ind, x = thresh_hard_sparse(x, k)
y = np.matmul(A, x, out=None)
x, error = niht(A, y, k, tol = 1e-4, MAX_ITER = 100)
plt.plot(error)
plt.yscale('log')
plt.xlabel('Iteration')
plt.ylabel('Relative error')
plt.show()

Implementing K-SVD

Basic implementation of K-SVD (Aharon, Elad, Bruckstein, 2006). Solve the following dictionary learning problem:

$$ \min_{D,X} \| Y - DX \|_F^2, \qquad \text{s.t.} \quad \|x_i\|_0\leq T, $$

where $Y\in R^{n\times N}$, $D\in R^{n\times K}$, $X\in R^{K\times N}$, with $N \geq n \geq K$.

  • $N$ is number of samples,
  • $n$ is dimension of each sample,
  • $K$ is number of all dictionary elements we wish to recover,
  • $T$ is how many dictionary elements are allowed to be used to approximate every sample.

In the class, we discussed Label Consistent KSVD (LC-KSVD) and Discriminative KSVD (D-KSVD) Label Consistent K-SVD: Learning a Discriminative Dictionary for Recognition (Jiang et al. 2013) On the Equivalence of the LC-KSVD and the D-KSVD Algorithms (Kviatkovsky et al. 2016)

There is work on analysing convolution neural networks as a sequence of sparse dictionary learning problems. Convolutional Neural Networks Analyzed via Convolutional Sparse Coding (Papyan et al. 2016)

In [61]:
def ksvd(Y,K,T, tol = 1e-4, MAX_ITER = 100, D_init = np.array([])):
    """
    Naive implementation of K-SVD, solving
    Y \approx D X,
    Parameters
    ----------
    Y : numpy array
        Matrix of data samples
    K : int
        Number of dictionary elements
    T : int
        Sparsity constraint on how many dictionary elements to be used for every sample
    Notes
    """
    error = np.zeros(MAX_ITER)
    # Get sizes of the problem
    n, N = Y.shape
    # Gaussian random initialization for D
    # D = np.random.randn(n, K) / np.sqrt(n)
    # Initialize dictionary with data
    if not D_init.size:
        D = Y[:,:K] / np.linalg.norm(Y[:,:K], axis=0)
    else:
        D = D_init.copy()
    X = np.zeros((K,N))
    # Iterative process
    J = 1
    not_finished = True
    while not_finished:
        # Sparse Coding Stage
        for i in range(N):
            #print(D.shape, Y[:,i].shape, Y.shape)
            x_tmp,_ = niht(D, Y[:,i], T, tol = 1e-7, MAX_ITER = 80)
            X[:,i] = x_tmp

        error[J-1] = np.linalg.norm(Y - D @ X, ord='fro')/np.linalg.norm(Y, ord='fro')
        print(error[J-1])
        # Codebook Update Stage
        for k in range(K):
            w_k = np.where(X[k,:] != 0)[0]
            if w_k.size > 1:
                ind = np.ones((K,), bool)
                ind[k] = False
                E_k = Y - D[:,ind] @ X[ind,:]
                u, s, vt = svds(E_k[:,w_k], k = 1)
                D[:,k] = u.T
                X[k,w_k] = s[0] * vt
        not_finished = (J < MAX_ITER) and (error[J-1] >= tol)
        J = J + 1
    return (D,X,error)

Test on MNIST

Warning: Takes a while to run. Setting $N$ (number of images used) and $K$ (number of dictionary elements) lower makes for faster computation. Making the sparse constraint $T$ higher allows for less sparse representations, resulting in lower error (at the expense of more dense representations).

In [ ]:
# Load MNIST using torchvision.datasets
from torchvision import datasets
data = datasets.MNIST('data', train=True, download=True)
X_numpy = data.train_data.numpy()/255
y_numpy = data.train_labels.numpy()

def draw_MNIST(image, label = ''):
    # Take a numpy array of 784 entries
    plt.imshow(image.reshape(28,28), cmap='gray')
    plt.title(label)
    plt.show()

def onehot(integer_labels):
    #Return matrix whose rows are onehot encodings of integers.
    onehotL = np.zeros((len(integer_labels), len(np.unique(integer_labels))), dtype='uint8')
    onehotL[np.arange(len(integer_labels)), integer_labels] = 1
    return onehotL

X_numpy = X_numpy.reshape(60000,-1)
y_numpy = onehot(y_numpy)
print(X_numpy.shape, y_numpy.shape)
In [95]:
N = 1500
Y = X_numpy[:N].T
D, X, error = ksvd(Y, K = 28**2, T = 15, MAX_ITER = 40, D_init = D)
0.20131201576247906
0.18764997515215973
0.18235121914202326
0.17974241414855793
0.17775234444953347
0.17609870327456612
0.17561183677189388
0.17508599972548536
0.17445872694157002
0.17409143954147502
0.1734174162407598
0.17273648051867208
0.17233848856695883
0.1722065957213382
0.1725592902827419
0.17214320841262004
0.17215227162155758
0.17245474108048975
0.17200432665568902
0.17279177934415535
0.1717683290010462
0.17174083737860132
0.1714399227380621
0.1705219974646979
0.17026453682846684
0.17083985832872603
0.1706361747830641
0.1712369479598984
0.170768717498246
0.17032913956431114
0.17067486401937604
0.17182391428898547
0.17159702958974513
0.17168913446553277
0.17144636171544067
0.17113907054032337
0.171595831639208
0.17189199751826814
0.1717872571118399
0.17198053862904022
0.1724047201192254
0.17224558479173765
0.17317707354771195
0.17249121450014662
0.1721425570082983
0.17300202260049063
0.17239149917915475
0.17258819814168974
0.17252878898812563
0.17233929674975296
0.17279306005220932
0.1725124925079447
0.17339025949740067
0.17271617334047226
0.17478620160423658
0.1737231500492816
0.17392426525212468
0.17404351312575664
0.1740485360274334
0.17403713788782704
0.1736891955822901
0.1734990312314022
0.1732785677107834
0.17230785591516223
0.17406400672515623
0.17474918746331672
0.17485966499180333
0.17499377714578246
0.1743640124833662
0.17435611105500312
0.17348456984797933
0.1738212322896442
0.17255207608734646
0.17260875522783983
0.17303928006386352
0.17325197335882478
0.17337377258600384
0.17368917962798663
0.17372986060173604
0.17440731191312595
0.1761742031722013
0.1754421489906588
0.17550086481053237
0.17561230175277304
0.17439100970932772
0.1742029544043899
0.17500217923845923
0.17486804748392137
0.17449684664999202
0.17533640360808422
0.17529518859640075
0.17606389978862286
0.17608271457993596
0.17704771621888543
0.17784990672400444
0.17747940221012376
0.17731678963869169
0.17666373880500447
0.1767669704290147
0.17598676795013704
0.17573759620121657
0.17551625259682535
0.17579200870706577
0.17542178217948826
0.17594892926266695
0.17594242759356726
0.17539922783333514
0.17469006132455558
0.1752845108405226
0.17547952898772687
0.17538468564472978
0.17508209476295225
0.17592523471103047
0.17657330778704813
0.1761840401495189
0.17456060750970423
0.17436065626753977
0.17545150055211797
0.17512468242217147
0.17329346266097287
0.17452966083079616
0.17488221271942633
0.1744338668515483
0.17594749375615304
0.17560266853212114
0.1762760059960145
0.17569946914466775
0.17608767742849396
0.17561941253768074
0.1766967915537099
0.17546541135424792
0.17626474190863803
0.17503384814091452
0.17523078611619475
0.17535672565877417
0.17446490719859575
0.17440112274288558
0.1750642035230936
0.1753913933540592
0.17433173461827717
0.17545942739370413
0.17590905212001934
0.1754300512048232
0.17597291967442463
0.1750999592359109
0.17612079263127775
0.1759154028778763
0.17680157942248304
0.17653987488400963
0.17527513101562656
0.1752555003074208
0.17534730678222066
0.17551836597709355
0.17542354812514924
0.17543483140681967
0.17556370266643379
0.17525080081195993
0.1754160822338056
0.1755938339748538
0.17512130582949945
0.17659944618473006
0.17575785978285746
0.1759530039269987
0.17673580526260593
0.17605268978823116
0.1759448758319378
0.17703821459115332
0.17604998612050105
0.17638868041290112
0.17708180196997805
0.17550228337826393
0.17563259229092068
0.17554203790964096
0.17485061058938312
0.17651594614883662
0.17591965651306485
0.17676094983493063
0.17776490636525913
0.17670585488544527
0.1763662122462358
0.177357780017128
0.17626075653987272
0.1765545426788318
0.17679938226199285
0.17660855012895552
0.17635664092335268
0.17645568243953264
0.1765557203182824
0.17635645994565768
0.17606822685841053
0.17624075710904924
0.17788187920230206
0.1782682120442601
0.17737541974496893
0.17750283814735607
0.17662096854232429
0.17619092880478865
0.17641712159762363
0.17715168547333668
0.17763781983213453
In [103]:
plt.plot(error)
plt.xlabel('K-SVD iteration')
plt.ylabel('Frobenius relative error')
plt.show()

Example of groundtruth versus its sparse approximation

Left: Sparse approximation in the learned dictionary for sample_id image.

Right: Groundtruth for sample_id image.

In [96]:
sample_id = 15
plt.subplot(2, 2, 1)
plt.imshow((D@X)[:,sample_id].reshape(28,28), cmap='gray')
plt.subplot(2, 2, 2)
plt.imshow(Y[:,sample_id].reshape(28,28), cmap='gray')

sample_id = 900
plt.subplot(2, 2, 3)
plt.imshow((D@X)[:,sample_id].reshape(28,28), cmap='gray')
plt.subplot(2, 2, 4)
plt.imshow(Y[:,sample_id].reshape(28,28), cmap='gray')
plt.show()

Plot dictionary elements sorted by how often they are used

In [97]:
tmp = np.count_nonzero(X, axis = 1)
order_used = np.argsort(-tmp)
plt.plot(tmp[order_used])
plt.yscale('log')
plt.xlabel('Dictionary element')
plt.ylabel('How many times used')
plt.show()

Draw 64 most commonly used dictionary elements

In [98]:
D_tmp = D[:,order_used[:64]]
D_tmp = D_tmp.reshape(28,28,8,8)
D_tmp = np.moveaxis(D_tmp, 2, 0)
D_tmp = np.moveaxis(D_tmp, 3, 2)
D_tmp = D_tmp.reshape(28*8, 28*8)
plt.figure(num=None, figsize=(8, 6), dpi=200, facecolor='w', edgecolor='k')
plt.imshow(D_tmp, cmap = 'gray')
plt.show()

Final remarks:

It would be interesting to rerun the experiment on a different dataset from MNIST and see what sort of dictionary elements we recover. It might be that for MNIST the variation between digits is low and therefore our dictionary consists of pictures representing limited number of ways of writing each digit.