飞道的博客

CNN特征提取结果可视化——hooks简单应用

577人阅读  评论(0)

本文代码地址https://github.com/njulhy/funny_code/blob/main/cnn_feature_visualization.ipynb

CNN特征提取结果可视化——hooks简单应用

在神经网络搭建时可能出现各式各样的错误,使用hook而非print或者简单的断点调试有助于你更清晰的意识到错误所在。

hook的使用场景多种多样,本文将使用hooks来简单可视化卷积神经网络的特征提取。用到的神经网络框架为Pytorch

Hooks简单介绍

每个hook都是预先定义好的可调用对象,在pytorch框架中,每个nn.Module对象都能够方便地注册(定义)一个hook。当一些trigger方法调用(如forward()backward())后,注册了hook的nn.Module对象会将相关信息传递到hook里面去。
在PyTorch中,可以注册三种hook:

  1. forward prehook (在forward之前执行)
  2. forward hook (在forward之后执行)
  3. backward hook (在backward之后执行)

具体理解每种hook的使用不是本文讨论的范围,我们将通过一个生动的卷积神经网络可视化例子来介绍hook的使用

CNN特征提取的简单可视化

我们将要进行的工作包括:

  1. 创建CNN特征提取器,本文使用PyTorch自带的resnet34
  2. 创建一个保存hook内容的对象
  3. 为每个卷积层创建hook
  4. 读取图像并进行特征提取
  5. 查看卷积层特征提取效果

本文将对下图进行特征提取并可视化

创建CNN特征提取器

import torch
import torchvision

feature_extractor = torchvision.models.resnet34(pretrained=True)
if torch.cuda.is_available():
	feature_extractor.cuda()

创建保存hook内容的对象

class SaveOutput:
	def __init__(self):
		self.outputs = []
	def __call__(self, module, module_in, module_out):
		self.outputs.append(module_out)
	def clear(self):
		self.outputs=[]
		
save_output = SaveOutput()

为卷积层注册hook

hook_handles = []

for layer in feature_extractor.modules():
	if isinstance(layer, torch.nn.Conv2d):
		handle = layer.register_forward_hook(save_output)
		hook_handles.append(handle)

读取图像并进行特征提取

cat.jpg地址

from PIL import Image
from torchvision import transforms as T

image = Image.open('cat.jpg')
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
X = transform(image).unsqueeze(dim=0).to(device)

out = feature_extractor(X)

查看卷积层特征提取效果

对于resnet来说,其具体结构如下:

卷积层共有1+6+(4*2+1)+(6*2+1)+(3*2+1)=36个,对conv3_x层有4*2+1卷积层的原因是(1)四个basicblock本身有4*2个卷积层(2)其中一个basicblock进行了downsample,又多了一个卷积层

查看卷积层数

此时每个卷积层的结果都通过hook保存到了save_output.outputs里面,我们查看是否为36个结果

可见全部卷积层的输出都保存了下来

可视化第一个卷积层

对resnet34来说,首个卷积层的卷积核为7*7,将输入的三通道彩色图像通道增加至64,尺寸从224*224对折为112*112,tensor的shape为1x64x112x112

我们对首个卷积层的提取结果进行可视化:

import matplotlib.pyplot as plt
plt.figure(figsize = (15,15))
plt.imshow(torchvision.utils.make_grid(save_output.outputs[0].cpu().permute(1, 0, 2, 3), nrow=8).permute(1, 2, 0))

emm这是第一个卷积层的提取结果,可爱的小猫咪开始黑化

可视化第二、七个卷积层

对resnet34来说,第2-7个卷积层tensor的shape为64x1x56x56,我们对其2个卷积层输出进行可视化:

plt.figure(figsize = (15,15))
plt.imshow(torchvision.utils.make_grid(save_output.outputs[1].cpu().permute(1, 0, 2, 3), nrow=8).permute(1, 2, 0))

可见第二个卷积层的结果更加模糊一些

第2-7个卷积层tensor的shape为64x1x56x56,我们对第七个卷积层也可视化:

plt.figure(figsize = (15,15))
plt.imshow(torchvision.utils.make_grid(save_output.outputs[6].cpu().permute(1, 0, 2, 3), nrow=8).permute(1, 2, 0))

可视化第16个卷积层

第16个卷积层对应的是conv3_x的结果,其shape为1x128x28x28,可视化如下

plt.figure(figsize = (15,30))
plt.imshow(torchvision.utils.make_grid(save_output.outputs[15].cpu().permute(1, 0, 2, 3), nrow=8).permute(1, 2, 0))

可见图像经过多层特征提取,提取到的特征变得更加高层,大部分通道已经变得难以辨认

结语

对神经网络提取结果进行可视化有助于理解其特征提取逐渐高层化的过程。
hook的使用场景还有很多,希望小伙伴们继续探索。


转载:https://blog.csdn.net/qq_34769162/article/details/115567093
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场