import numpy as np
from scipy.sparse.linalg import svds
# pywt for wavelet transform implementation
import pywt
import pywt.data
# Wrapper function around pywt transform (to get vector as an output)
from func.transforms2d import transform_wavelet2
# matplotlib for plotting
import matplotlib.pyplot as plt
%matplotlib inline
Techniques for finding and using sparse representations are widely used in image processing (e.g. lossy image and video compression).
A simple demonstration of how the cameraman image $y\in R^{512^2}$ can be compressed using the Daubechies 4 wavelet transform:
# Load the cameraman image and normalize it
image = pywt.data.camera()
image = image / np.linalg.norm(image, 'fro')
plt.imshow(image, cmap='gray')
plt.show()
# Daubechies 4 wavelet transform (5 levels, periodic boundary)
Psi = transform_wavelet2('Db4', mode='per', level=5)
# Vector of wavelet coefficients
x = Psi.forward(image)
plt.semilogy(np.sort(np.abs(x))[::-1], linewidth=2)
plt.show()
Construct $\Psi$ to be a wavelet transform (Daubechies 4). Applying $x = \Psi y$, we obtain $x\in R^{512^2}$, a vector of wavelet coefficients.
Notice the fast decay of the coefficient values in the figure above. This suggests that most of the information content of the image is contained in only a few coefficients.
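To quantify this (a suggested check, not part of the original notebook), we can measure how much of the squared norm of $x$ is captured by the largest $5\%$ of coefficients:
# Fraction of the coefficient energy captured by the largest 5% of coefficients (suggested check)
coeff_sorted = np.sort(np.abs(x).ravel())[::-1]
k5 = round(0.05 * coeff_sorted.size)
print('Energy in top 5% of coefficients:', np.sum(coeff_sorted[:k5]**2) / np.sum(coeff_sorted**2))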
Keep only the largest $5\%$ of the coefficients (in absolute value). Transform back into the image domain and observe the resulting image. $$ x_k = HT(x, 5\%)\\ y_k = \Psi^*(x_k) $$
def thresh_hard_sparse(x, k):
    """
    Keep only the k largest entries of x (in absolute value) and return their indices.

    Parameters
    ----------
    x : numpy array
        Numpy array to be thresholded
    k : int
        Number of largest entries in absolute value to keep

    Returns
    -------
    ind : tuple of numpy arrays
        Indices of the k largest entries in absolute value
    _x : numpy array
        Hard-thresholded copy of x
    """
    _x = x.copy()
    # Indices of the k largest entries in absolute value
    ind = np.argpartition(abs(_x), -k, axis=None)[-k:]
    ind = np.unravel_index(ind, _x.shape)
    # Zero out all remaining entries
    ind_del = np.ones(_x.shape, dtype=bool)
    ind_del[ind] = False
    _x[ind_del] = 0
    return ind, _x
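A quick illustration of the thresholding function on a small vector (a suggested example, not part of the original notebook):
# Keep the 2 largest entries (in absolute value) of a small test vector
ind_demo, x_demo = thresh_hard_sparse(np.array([3.0, -5.0, 1.0, 0.5]), 2)
print(ind_demo, x_demo)   # keeps 3.0 and -5.0, zeroes the rest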
rho = 0.05
_, x_k = thresh_hard_sparse(x, round(rho*512**2))
image_k = Psi.backward(x_k)
fig, (ax1, ax2) = plt.subplots(1, 2,figsize=(14, 3))
ax1.semilogy(np.sort(np.abs(x))[::-1], linewidth=2)
ax2.semilogy(np.sort(np.abs(x_k))[::-1], linewidth=2)
plt.show()
fig, (ax1, ax2) = plt.subplots(1, 2,figsize=(16, 12))
ax1.imshow(image, cmap ='gray')
ax1.set_title('Cameraman')
ax2.imshow(image_k, cmap ='gray')
ax2.set_title('Cameraman thresholded')
plt.show()
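As a rough quantitative check (a suggested addition), we can compare the thresholded reconstruction to the original image:
# Relative error of the 5% reconstruction in the image domain (suggested check)
rel_err = np.linalg.norm(image - image_k, 'fro') / np.linalg.norm(image, 'fro')
print('Relative reconstruction error:', rel_err)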
We first implement NIHT (Blumensath & Davies 2010) for solving an underdetermined linear system with a sparsity constraint:
$$ \min_{x} \| Ax - y \|_2^2, \qquad \text{s.t.}\quad \|x\|_0\leq k, $$where $A\in R^{m\times n}$, $x\in R^n$, and $y\in R^{m}$ (with $n\geq m \geq k$).
Note: For those of you who took C6.1 Numerical Linear Algebra last term and know of the Conjugate Gradient (CG) method for solving linear systems (Hestenes & Stiefel 1952), there is an analogous method for solving underdetermined linear systems with sparsity constraints, CGIHT (Blanchard et al. 2015), which has faster convergence rates.
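For reference, each iteration of the implementation below performs a gradient step whose step size is normalized on the current support, followed by hard thresholding back to $k$-sparsity:
$$ g^{(l)} = A^\top\big(y - A x^{(l)}\big), \qquad \alpha_l = \frac{\|g^{(l)}_{T_l}\|_2^2}{\|A\, g^{(l)}_{T_l}\|_2^2}, \qquad x^{(l+1)} = HT\big(x^{(l)} + \alpha_l\, g^{(l)},\ k\big), $$
where $T_l$ is the support of the current iterate and $g^{(l)}_{T_l}$ denotes the gradient restricted to that support. (The full NIHT algorithm also includes a safeguard on the step size when the support changes, which this simplified implementation omits.)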
def support_projection(x, ind):
    """
    Keep only the entries of x at the specified indices, setting all others to zero.

    Parameters
    ----------
    x : numpy array
        Numpy array to be projected
    ind : tuple of numpy arrays
        Indices of the entries to keep
    """
    _x = x.copy()
    # Zero out everything outside the support
    ind_del = np.ones(x.shape, dtype=bool)
    ind_del[ind] = False
    _x[ind_del] = 0
    return _x
def niht(A, y, k, tol = 1e-4, MAX_ITER = 100):
    """
    Normalized Iterative Hard Thresholding for the underdetermined system
        Ax = y,
    subject to the sparsity constraint ||x||_0 <= k.

    Parameters
    ----------
    A : numpy array
        Underdetermined matrix
    y : numpy array
        Right-hand side vector
    k : int
        Sparsity constraint on x
    tol : float
        Relative residual tolerance used as the stopping criterion
    MAX_ITER : int
        Maximum number of iterations

    Returns
    -------
    x : numpy array
        k-sparse approximate solution
    error : numpy array
        Relative residual norm at each iteration
    """
    error = np.zeros(MAX_ITER)
    # Initialization: hard-threshold A^T y to obtain the first k-sparse iterate
    w = A.T.dot(y)
    T_k, x = thresh_hard_sparse(w, k)
    error[0] = np.linalg.norm(A.dot(x) - y)/np.linalg.norm(y)
    # Iterative process
    l = 2
    not_finished = True
    while not_finished:
        # Gradient of the least-squares objective
        r = A.T.dot(y-A.dot(x))
        # Step size normalized on the current support
        r_proj = support_projection(r, T_k)
        a = np.linalg.norm(r_proj)**2
        b = np.linalg.norm(A.dot(r_proj))**2
        alpha = a/b
        # Gradient step followed by hard thresholding
        w = x + alpha * r
        T_k, x = thresh_hard_sparse(w, k)
        error[l-1] = np.linalg.norm(A.dot(x) - y)/np.linalg.norm(y)
        not_finished = (l < MAX_ITER) and (error[l-1] >= tol)
        l = l + 1
    return (x, error[:(l-1)])
Here we test whether the NIHT solver works on a generated toy problem. Consider a matrix $A\in R^{200\times 500}$ with i.i.d. entries sampled from a Gaussian distribution. Generate a random $k$-sparse $x\in R^{500}$, $\|x\|_0 = k$, and observe $y = Ax$. We run NIHT to test whether we can recover $x$ from knowing only $y$ and $A$.
m = 200
n = 500
k = 10
# Generate a random k-sparse vector x and a Gaussian sensing matrix A
x = np.random.randn(n, 1)
A = np.random.randn(m, n) / np.sqrt(m)
ind, x = thresh_hard_sparse(x, k)
y = A @ x
# Recover x from y and A (keep the ground truth x for comparison)
x_rec, error = niht(A, y, k, tol = 1e-4, MAX_ITER = 100)
plt.plot(error)
plt.yscale('log')
plt.xlabel('Iteration')
plt.ylabel('Relative error')
plt.show()
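As a quick sanity check (a suggested addition, not part of the original notebook), we can compare the recovered vector directly with the ground truth:
# Relative recovery error and support check (suggested sanity check)
rec_err = np.linalg.norm(x_rec - x) / np.linalg.norm(x)
support_match = np.array_equal(np.nonzero(x_rec.ravel())[0], np.nonzero(x.ravel())[0])
print(f'Relative recovery error: {rec_err:.2e}, support recovered: {support_match}')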
A basic implementation of K-SVD (Aharon, Elad & Bruckstein 2006). We solve the following dictionary learning problem:
$$ \min_{D,X} \| Y - DX \|_F^2, \qquad \text{s.t.} \quad \|x_i\|_0\leq T, $$where $Y\in R^{n\times N}$, $D\in R^{n\times K}$, $X\in R^{K\times N}$, $x_i$ denotes the $i$-th column of $X$, and $N \geq n \geq K$.
In class, we discussed Label Consistent K-SVD (LC-KSVD) and Discriminative K-SVD (D-KSVD); see "Label Consistent K-SVD: Learning a Discriminative Dictionary for Recognition" (Jiang et al. 2013) and "On the Equivalence of the LC-KSVD and the D-KSVD Algorithms" (Kviatkovsky et al. 2016).
There is also work on analysing convolutional neural networks as a sequence of sparse dictionary learning problems; see "Convolutional Neural Networks Analyzed via Convolutional Sparse Coding" (Papyan et al. 2016).
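For reference, the codebook update stage in the implementation below proceeds atom by atom: for each dictionary column $d_k$, we form the residual that excludes atom $k$, restrict it to the samples $\omega_k$ that currently use atom $k$, and replace $(d_k,\, x^k_{\omega_k})$ by the best rank-1 approximation of that restricted residual:
$$ E_k = Y - \sum_{j\neq k} d_j x^j, \qquad \omega_k = \{i : x^k_i \neq 0\}, \qquad (d_k,\, x^k_{\omega_k}) \leftarrow \underset{\text{rank-1}}{\arg\min}\ \| E_k^{\omega_k} - d_k x^k_{\omega_k} \|_F^2, $$
where $x^j$ denotes the $j$-th row of $X$ and $E_k^{\omega_k}$ denotes the columns of $E_k$ indexed by $\omega_k$; the rank-1 factor is computed with a truncated SVD (scipy's svds with $k=1$).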
def ksvd(Y, K, T, tol = 1e-4, MAX_ITER = 100, D_init = np.array([])):
    """
    Naive implementation of K-SVD, solving
        Y \approx D X,
    with every column of X at most T-sparse.

    Parameters
    ----------
    Y : numpy array
        Matrix of data samples (one sample per column)
    K : int
        Number of dictionary elements
    T : int
        Sparsity constraint on how many dictionary elements to be used for every sample
    tol : float
        Relative Frobenius error tolerance used as the stopping criterion
    MAX_ITER : int
        Maximum number of K-SVD iterations
    D_init : numpy array
        Optional initial dictionary; if empty, the dictionary is initialized from the data

    Returns
    -------
    D : numpy array
        Learned dictionary
    X : numpy array
        Sparse representation of Y in the dictionary D
    error : numpy array
        Relative Frobenius error at each iteration
    """
    error = np.zeros(MAX_ITER)
    # Get sizes of the problem
    n, N = Y.shape
    # Gaussian random initialization for D
    # D = np.random.randn(n, K) / np.sqrt(n)
    # Initialize dictionary with (normalized) data columns
    if not D_init.size:
        D = Y[:,:K] / np.linalg.norm(Y[:,:K], axis=0)
    else:
        D = D_init.copy()
    X = np.zeros((K,N))
    # Iterative process
    J = 1
    not_finished = True
    while not_finished:
        # Sparse Coding Stage: T-sparse coding of every sample via NIHT
        for i in range(N):
            x_tmp, _ = niht(D, Y[:,i], T, tol = 1e-7, MAX_ITER = 80)
            X[:,i] = x_tmp
        error[J-1] = np.linalg.norm(Y - D @ X, ord='fro')/np.linalg.norm(Y, ord='fro')
        print(error[J-1])
        # Codebook Update Stage: update one dictionary atom at a time
        for k in range(K):
            # Samples that currently use atom k
            w_k = np.where(X[k,:] != 0)[0]
            if w_k.size > 1:
                # Residual excluding atom k
                ind = np.ones((K,), bool)
                ind[k] = False
                E_k = Y - D[:,ind] @ X[ind,:]
                # Rank-1 approximation of the residual restricted to w_k
                u, s, vt = svds(E_k[:,w_k], k = 1)
                D[:,k] = u[:,0]
                X[k,w_k] = s[0] * vt[0,:]
        not_finished = (J < MAX_ITER) and (error[J-1] >= tol)
        J = J + 1
    return (D, X, error[:(J-1)])
Warning: this takes a while to run. Setting $N$ (the number of images used) and $K$ (the number of dictionary elements) lower makes for faster computation. Making the sparsity constraint $T$ larger allows for less sparse representations, resulting in lower error (at the expense of denser representations).
# Load MNIST using torchvision.datasets
from torchvision import datasets
data = datasets.MNIST('data', train=True, download=True)
# data.data / data.targets hold the raw images and labels
X_numpy = data.data.numpy()/255
y_numpy = data.targets.numpy()
def draw_MNIST(image, label = ''):
    # Takes a flat numpy array of 784 entries and displays it as a 28x28 image
    plt.imshow(image.reshape(28,28), cmap='gray')
    plt.title(label)
    plt.show()
def onehot(integer_labels):
#Return matrix whose rows are onehot encodings of integers.
onehotL = np.zeros((len(integer_labels), len(np.unique(integer_labels))), dtype='uint8')
onehotL[np.arange(len(integer_labels)), integer_labels] = 1
return onehotL
X_numpy = X_numpy.reshape(60000,-1)
y_numpy = onehot(y_numpy)
print(X_numpy.shape, y_numpy.shape)
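To sanity-check the data (a suggested step, not part of the original notebook), we can display one of the samples with the helper defined above:
# Display the first sample together with its label
draw_MNIST(X_numpy[0], label = str(np.argmax(y_numpy[0])))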
N = 1500
Y = X_numpy[:N].T
# To warm-start from a previous run, pass D_init = D
D, X, error = ksvd(Y, K = 28**2, T = 15, MAX_ITER = 40)
plt.plot(error)
plt.xlabel('K-SVD iteration')
plt.ylabel('Frobenius relative error')
plt.show()
Left column: sparse approximations $Dx_i$ in the learned dictionary for the selected sample_id images.
Right column: ground truth for the same images.
sample_id = 15
plt.subplot(2, 2, 1)
plt.imshow((D@X)[:,sample_id].reshape(28,28), cmap='gray')
plt.subplot(2, 2, 2)
plt.imshow(Y[:,sample_id].reshape(28,28), cmap='gray')
sample_id = 900
plt.subplot(2, 2, 3)
plt.imshow((D@X)[:,sample_id].reshape(28,28), cmap='gray')
plt.subplot(2, 2, 4)
plt.imshow(Y[:,sample_id].reshape(28,28), cmap='gray')
plt.show()
tmp = np.count_nonzero(X, axis = 1)
order_used = np.argsort(-tmp)
plt.plot(tmp[order_used])
plt.yscale('log')
plt.xlabel('Dictionary element')
plt.ylabel('How many times used')
plt.show()
D_tmp = D[:,order_used[:64]]
D_tmp = D_tmp.reshape(28,28,8,8)
D_tmp = np.moveaxis(D_tmp, 2, 0)
D_tmp = np.moveaxis(D_tmp, 3, 2)
D_tmp = D_tmp.reshape(28*8, 28*8)
plt.figure(num=None, figsize=(8, 6), dpi=200, facecolor='w', edgecolor='k')
plt.imshow(D_tmp, cmap = 'gray')
plt.show()
It would be interesting to rerun the experiment on a dataset other than MNIST and see what sort of dictionary elements we recover. It might be that for MNIST the variation between digits is low, and therefore our dictionary consists of pictures representing a limited number of ways of writing each digit.