STN网络是谷歌提出来一个模型,可以看成CNN网络中的一个插件,从效果上看,STN网络能通过无监督的方式自动学习,找到图片中的ROI区域并进行一系列逆变换,将ROI区域的物体转换为正常形态,实现空间不变性。本文参考了此篇博客
一个好的分类模型应该能将物体的姿态和形变与纹理和形状解耦,也就是说,即使纹理或形状发生改变,模型也能将其正确分类。CNN网络的max-pooling层对特征的位置具有一定程度的空间不变性,但由于kernel的大小有限,所以作用范围有限;另外,pooling操作是不记录特征的位置的,也就是说只要图片中出现了人特征(如鼻子,眼睛,嘴等)就将其分为人这一类,而并不管这些特征是不是在对应的位置,是不是组合成了一个人脸;再另外,max-pooling操作实际上忽略了feature map中75%的信息,会损失一些信息,kernel是定义好的,不可训练,并且低层的空间不变性并不是很强。
而STN网络则解决了以上问题,STN可以插入到任何网络中,能通过端到端的方式训练,而且输入是feature map的某一个channel,而pooling操作是对所有的输入样本都用一个操作。
在了解STN之前,先了解仿射变换。仿射变换=线性变换+平移,包括平移、旋转、缩放、斜切,可以表示为y=Ax+b,其中A可以实现旋转和缩放,除A外,其余均为向量。使用增广矩阵和增广向量,使用矩阵乘法同时表示线性变换和平移:
这个就是仿射变换的源像素与目标像素坐标变换的公式,其中,s和t分别表示源像素和目标像素,等号左侧计算得到的就是经过仿射变换后的像素位置,所以仿射变换就可以简化为6个参数,而STN网络正是要自动学习这6个参数。得到变换后的像素位置后,我们就可以根据源像素和目标像素的对应关系计算目标像素位置的像素值。由于计算得到的目标像素位置不一定是整数,所以目标像素值是通过双线性插值来计算的。
下面是STN网络的介绍。STN网络分成三个部分:Localisation Network, Parameterised Sampling Grid, Differentiable Image Sampling。
首先,定位网络,输入大小为(H, W, C)的feature map,通过全连接或卷积操作输入形状为(6,)的变换矩阵;
其次是参数化采样栅格,这个部分的作用是输出一些点,点的选取是根据输入的图像中应该采样哪些点才能得到想要的输出,也就是未经变换的ROI区域,如下图所示,图中U为原图像,其中的像素点为源像素;V为采样栅格经仿射变换得到的与图U同样大小的图片,其中的像素为目标像素。图a中变换矩阵为单位矩阵,也就是同等映射,ROI区域为原图像;图b中变换矩阵为仿射变换矩阵,ROI区域为倾斜的9。这里实际上就可以表示为刚刚提到的仿射变换的公式,只是将s和t是互换的,也就是说,等号左侧是源像素,等号右侧是目标像素,之所以这样操作是因为,这一部分和下一部分是配套使用的。另外,这一部分的坐标是经过归一化的,坐标范围为[-1,1],目的是让仿射变换的中心在图像的中心,而不是左上角,这样对图像进行变换,如旋转的时候是沿着图像中心,而非左上角
最后是可微分图像采样部分,之前说第二部分和第三部分是搭配使用的,是因为第三部分主要是计算目标像素位置的像素值,这里主要是用双线性插值法。在第二部分计算出当前目标像素位置对应的源像素位置后,结合源像素周围四个像素点的像素值,使用双线性插值的方法计算目标像素值,也就是说,第二部分的输出源像素位置是第三部分的输入,第三部分就得到了目标像素位置的像素值,所有目标像素位置的像素值均被计算后,就得到了目标图像,也就是经仿射变换后的图像,完成了整个仿射变换的过程。
下面是一些STN网络的结果
STN网络结果的变换过程可以参考这个网址,当然STN并不局限于仿射变换,不同的变换方法只是参数数量不同而已。
心塞,上一次更新内容时没有提交源文件,导致只生成了网页,没有对应的源文件,再次更新时,只能把之前的两篇又重新写了一遍。事实证明,实践出真知,还是要多动手,多应用呀!