一维向量和二维矩阵卷积的等价矩阵乘法实现

经过学习和探索,掌握了一维卷积和二维卷积的矩阵乘法实现。

对于一维向量的卷积,Matlab提供了convmtx计算向量对应的乘法矩阵。它是一个Toeplitz矩阵。

% By TomHeaven, hanlin_tan@nudt.edu.cn @ 2016.09.28
x = [1 2 3]';
y = [4 5 6]';
 
n = length(y);
H = convmtx(x,n);
z = H*y;
 
z_c = conv(x, y);
 
diff = norm(z - z_c)

二维卷积(图像与核卷积)的矩阵乘法形式,矩阵的构造方法可以通过简单的算理推演得到。它是一个Band(带宽)矩阵。

设二维图像矩阵为$X \in R^{m \times n}$,卷积核矩阵为$Y \in R^{p \times p}$。它们按行优先展开得到的向量分别为$x,y$。


$$
X= \left[
\begin{matrix}
1 & 2 & 3 & 4 & 5\\
6 & 7 & 8 & 9 & 10 \\
11 & 12 & 13 & 14 & 15
\end{matrix}
\right]
$$

$$
Y= \left[
\begin{matrix}
1 & 2 & 3 \\
4 & 5 & 6 \\
7 & 8 & 9
\end{matrix}
\right]
$$

再令$Z = X * Y$,$z$为$Z$按行优先展开所得的向量。

则可以构造矩阵
$$
H = \left[
\begin{matrix}
9 & 8 & 7 & 0 & 0 & 6 & 5 & 4 & 0 & 0 & 3 & 2 & 1 & 0 & 0 \\
0 & 9 & 8 & 7 & 0 & 0 & 6 & 5 & 4 & 0 & 0 & 3 & 2 & 1 & 0 \\
0 & 0 & 9 & 8 & 7 & 0 & 0 & 6 & 5 & 4 & 0 & 0 & 3 & 2 & 1 \\
\end{matrix}
\right]
$$
使得
$$
Hx = z
$$

以上示例已经蕴含了$H$的构造法则:
1. $H$的大小为$m \times {pn}$。
2. 第一行划分为$k$组,每组$m$个元素。第$i$组前$k$个元素为$Y$的$p-i+1$行元素倒序排列,剩下$n-p$个元素为0。
3. 第i+1行为第i行右移一位,末尾的元素(0)补到行的开头。

用Matlab实现为

% By TomHeaven, hanlin_tan@nudt.edu.cn @ 2016.09.28
x = (1 : 24)';
y = (1 : 9)';
 
X = reshape(x, [4 6])';
Y = reshape(y, [3 3])'; % square kernel
 
% construct convolutional multiplication matrix H
Y_hat = [ fliplr(flipud(Y)) zeros(size(Y, 1), size(X,2) - size(Y, 2))];
y_hat = reshape(Y_hat', [1, size(Y_hat,1) * size(Y_hat,2)]);
y_hat = [ y_hat zeros(1, (size(X,1) - size(Y,1)) * size(X,2) ) ];
 
H = zeros((size(X,1) - size(Y,1) + 1) * (size(X,2) - size(Y,2) + 1), length(x));
%H(1, :) = y_hat;
len = length(y_hat);
cnt = 0;
for i = 1: size(X,1) - size(Y, 1) + 1
    for j = 1 : size(X,2) - size(Y, 2) + 1
       cnt = cnt + 1;
       H(cnt,:) = y_hat;
       y_hat(2:len) = y_hat(1:len - 1);
       y_hat(1) = 0;
    end
    % skip invalid convolution
    for j = 1 :   size(X,2) - (size(X,2) - size(Y, 2) + 1)
       y_hat(2:len) = y_hat(1:len - 1);
       y_hat(1) = 0;
    end
end
 
z = H * x;
z = reshape(z', [(size(X,2) - size(Y,2) + 1), (size(X,1) - size(Y,1) + 1)   ])';
 
Z = conv2(X, Y, 'valid');
 
diff = norm(z - Z)