飞道的博客

对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战

386人阅读  评论(0)

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题

🍊往期回顾:对抗生成网络GAN系列——GAN原理及手写数字生成小案例    对抗生成网络GAN系列——DCGAN简介及人脸图像生成案例

🍊近期目标:写好专栏的每一篇文章

🍊支持小苏:点赞👍🏼、收藏⭐、留言📩

 

对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战

写在前面

  随着深度学习的发展,已经有很多学者将深度学习应用到物体瑕疵检测中,如列车钢轨的缺陷检测、医学影像中各种疾病的检测。但是瑕疵检测任务几乎都存在一个共同的难题——缺陷数据太少了。我们使用这些稀少的缺陷数据很难利用深度学习训练一个理想的模型,往往都需要进行数据扩充,即通过某些手段增加我们的缺陷数据。【数据扩充大家感兴趣自己去了解下,GAN网络也是实现数据扩充的主流手段】上面说到的方法是基于缺陷数据来训练的,是有监督的学习,学者们在漫长的研究中,考虑能不能使用一种无监督的方法来实现缺陷检测呢?于是啊,AnoGAN就横空出世了,它不需要缺陷数据进行训练,而仅使用正常数据训练模型,关于AnoGAN的细节后文详细介绍。

​  关于GAN网络,我已经介绍了几篇,如下:

​  在阅读本文之前建议大家对GAN有一定的了解,可以参考[1]和[2],关于[3]感兴趣的可以看看,本篇文章用不到[3]相关知识。

  准备好了嘛,我们开始发车了喔。🚖🚖🚖

AnoGAN 原理详解✨✨✨

  首先我们来看看AnoGAN的全称,即Anomaly Detection with Generative Adversarial Networks ,中文是指使用生成对抗网络实现异常检测。这篇论文解决的是医学影像中疾病的检测,由于对医学相关内容不了解,本文将完全将该算法从论文中剥离,只介绍算法原理,而不结合论文进行讲述。想要了解论文详情的可以点击☞☞☞查看。

​  接下来就随我一起来看看AnoGAN的原理。其实AnoGAN的原理是很简单的,但是我看网上的资料总是说的摸棱两可,我认为主要原因有两点:其一是没有把AnoGAN的原理分步来叙述,其二是有专家视角,它们认为我们都应该明白,但这对于新手来说理解也确实是有一定难度的。

​  在介绍AnoGAN的具体原理时,我先来谈谈AnoGAN的出发点,这非常重要,大家好好感受。我们知道,DCGAN是将一个噪声或者说一个潜在变量映射成一张图片,在我们训练DCGAN时,都是使用某一种数据进行的,如[2]中使用的数据都是人脸,那么这些数据都是正常数据,我们从一个潜在变量经DCGAN后生成的图片应该也都是正常图像。AnoGAN的想法就是我能否将一张图片M映射成某个潜在变量呢,这其实是较难做到的。但是我们可以在某个空间不断的查找一个潜在变量,使得这个潜在变量生成的图片与图片M尽可能接近。这就是AnoGAN的出发点,大家可能还不明白这么做的意义,下文为大家详细介绍。☘☘☘

​  AnoGAN其实是分两个阶段进行的,首先是训练阶段,然后是测试阶段,我们一点点来看:

  • 训练阶段

    训练阶段仅使用正常的数据训练对抗生成网络。如我们使用手写数字中的数字8作为本阶段的数据进行训练,那么8就是正常数据。训练结束后我们输入一个向量z,生成网络会将z变成8。不知道大家有没有发现其实这阶段就是[2]中的DCGAN呢?【注意:训练阶段已经训练好GAN网络,后面的测试阶段GAN网络的权重是不在变换的】

  • 测试阶段

    在训练阶段我们已经训练好了一个GAN网络,在这一阶段我们就是要利用训练好的网络来进行缺陷检测。如现在我们有一个数据6,此为缺陷数据【训练时使用8进行训练,这里的6即为缺陷数据】。现在我们要做的就是搜索一个潜在变量并让其生成的图片与图片6尽可能接近,具体实现如下:首先我们会定义一个潜在变量z,然后经过刚刚训练的好的生成网络,得到假图像G(z),接着G(z)和缺陷数据6计算损失,这时候损失往往会比较大,我们不断的更新z值,会使损失不断的减少,在程序中我们可以设置更新z的次数,如更新500次后停止,此时我们认为将如今的潜在变量z送入生成网络得到的假图像已经和图片6非常像了,于是我们将z再次送入生成网络,得到G(z)。注:由于潜在变量z送入的网络是生成图片8的,尽管通过搜索使G(z)和6尽可能相像,但还是存在一定差距,即它们的损失较大】最后我们就可以计算G(z)和图片6的损失,记为loss1,并将这个损失作为判断是否有缺陷的重要依据。怎么作为判断是否有缺陷的重要依据呢?我再举个例子大家就明白了,现在在测试阶段我们传入的不是缺陷数据,而是正常的数据8,此时应用相同的方法搜索潜在变量z,然后将最终的z送入生成网络,得到G(z),最后计算G(z)和图片8的损失。【注:由于潜在变量z送入的网络是生成图片8的,所以最后生成的G(z)可以和数据8很像,即它们的损失较小】通过以上分析, 我们可以发现当我们在测试阶段传入缺陷图片时最终的损失大,传入正常图片时的损失小,这时候我们就可以设置一个合适的阈值来判断图像是否有缺陷了。🥂🥂🥂
    这一段是整个AnoGAN的重点,大家多思考思考,相信你可以理解。我也画了一个此过程的流程图,大家可以参考一下,如下:


  读了上文,是不是对AnoGAN大致过程有了一定了解了呢!我觉得大家训练阶段肯定是没问题的啦,就是一个DCGAN网络,不清楚这个的话建议阅读[2]了解DCGAN网络。测试阶段的难点就在于我们如何定义损失函数来更新z值,我们直接来看论文中此部分的损失,主要分为两部分,分别是Residual Loss和Discrimination Loss,它们定义如下:

  • Residual Loss

    R ( z ) = ∑ ∣ x − G ( z ) ∣ {\rm{R}}(z) = \sum {|x - G(z)|} R(z)=xG(z)

    上式z表示潜在变量,G(z)表示生成的假图像,x表示输入的测试图片。上式表示生成的假图像和输入图片之间的差距。如果生成的图片越接近x,则 R ( z ) R(z) R(z)越小。

  • Discrimination Loss

    D ( z ) = ∑ ∣ f ( x ) − f ( G ( z ) ) ∣ D(z) = \sum {|f(x) - f(G(z))|} D(z)=f(x)f(G(z))

    上式z表示潜在变量,G(z)表示生成的假图像,x表示输入的测试图片。f(*)表示将*通过判别器,然后取判别器某一层的输出结果。【注:这里使用的并非判别器的最终输出,而是判别器某层的输出,关于这一点,会在代码讲解时介绍】 这里可以把判别器当作一个特征提取网络,我们将生成的假图片和测试图片都输入判别器,看它们提取到特征的差异。同样,如果生成的图片越接近x,则 D ( z ) D(z) D(z)越小。

求得 R ( z ) R(z) R(z) D ( z ) D(z) D(z)后,我们定义它们的线性组合作为最终的损失,如下:

L o s s ( z ) = ( 1 − λ ) R ( z ) + λ D ( z ) Loss(z)=(1-\lambda)R(z)+\lambda D(z) Loss(z)=(1λ)R(z)+λD(z)

通常,我们取 λ = 0.1 \lambda =0.1 λ=0.1


​  到这里,AnoGAN的理论部分都介绍完了喔!!!不知道你理解了多少呢?如果觉得有些地方理解还差点儿意思的话,就来看看下面的代码吧,这回对你理解AnoGAN非常有帮助。🌱🌱🌱

 

AnoGAN代码实战

​  如果大家和我一样找过AnoGAN代码的话,可能就会和我有一样的感受,那就是太乱了。怎么说呢,我认为从原理上来说,应该很好实现AnoGAN,但是我看Github上的代码写的挺复杂,不是很好理解,有的甚至起着AnoGAN的名字,实现的却是一个简单的DCGAN网络,着实让人有些无语。于是我打算按照自己的思路来实现一个AnoGAN,奈何却出现了各种各样的Bug,正当我心灰意冷时,看到了一篇外文的博客,写的非常对我的胃口,于是按照它的思路实现了AnoGAN。这里我还是想感概一下,我发现很多外文的博客确实写的非常漂亮,我想这是值得我们学习的地方!!!🌼🌼🌼

 

代码下载地址✨✨✨

​  本次我将源码上传到我的Github了,大家可以阅读README文件了解代码的使用,Github地址如下:

AnoGAN-pytorch实现

​  我认为你阅读README文件后已经对这个项目的结构有所了解,我在下文也会帮大家分析分析源码,但更多的时间大家应该自己动手去亲自调试,这样你会有不一样的收获。🌾🌾🌾

 

数据读取✨✨✨

​  本次使用的数据为mnist手写数字数据集,我们下载的是.csv格式的数据,这种格式方便读取。读取数据代码如下:

## 读取训练集数据  (60000,785)
train = pd.read_csv(".\data\mnist_train.csv",dtype = np.float32)
## 读取测试集数据  (10000,785)
test = pd.read_csv(".\data\mnist_test.csv",dtype = np.float32)

  我们可以来看一下mnist数据集的格式是怎样的,先来看看train中的内容,如下:

​  train的shape为(60000,785),其表示训练集中共有60000个数据,即60000张手写数字的图片,每个数据都有785个值。我们来分析一下这785个数值的含义,第一个数值为标签label,表示其表示哪个手写数字,后784个数值为对应数字每个像素的值,手写数字图片大小为28×28,故一共有784个像素值。

​  解释完训练集数据的含义,那测试集也是一样的啦,只不过数据较少,只有10000条数据,test的内容如下:

​  大家需要注意的是,上述的训练集和测试集中的数据我们今天并不会全部用到。我们取训练集中的前400个标签为7或8的数据作为AnoGAN的训练集,即7、8都为正常数据。取测试集前600个标签为2、7、8作为测试数据,即测试集中有正常数据(7、8)和异常数据(2),相关代码如下:

# 查询训练数据中标签为7、8的数据,并取前400个
train = train.query("label in [7.0, 8.0]").head(400)

# 查询训练数据中标签为7、8的数据,并取前400个
test = test.query("label in [2.0, 7.0, 8.0]").head(600)

  可以看看此时的train和test的结果:

​  在AnoGAN中,我们是无监督的学习,因此是不需要标签的,通过以下代码去除train和test中的标签:

# 取除标签后的784列数据
train = train.iloc[:,1:].values.astype('float32')
test = test.iloc[:,1:].values.astype('float32')

  去除标签后train和test的结果如下:

  可以看出,此时train和test中已经没有了label类,它们的第二个维度也从785变成了784。

  最后,我们将train和test reshape成图片的格式,即28×28,代码如下:

# train:(400,784)-->(400,28,28)
# test:(600,784)-->(600,28,28)
train = train.reshape(train.shape[0], 28, 28)
test = test.reshape(test.shape[0], 28, 28)

​  此时,train和test的维度发生变换,如下图所示:

  至此,我们的数据读取部分就为大家介绍完了,是不是发现挺简单的呢,加油吧!!!🥂🥂🥂

 

模型搭建

​  模型搭建真滴很简单!!!大家之间看代码吧。🌻🌻🌻

生成模型搭建

"""定义生成器网络结构"""
class Generator(nn.Module):

  def __init__(self):
    super(Generator, self).__init__()

    def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1, activation=nn.ReLU(inplace=True), bn=True):
        seq = []
        seq += [nn.ConvTranspose2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding)]
        if bn is True:
          seq += [nn.BatchNorm2d(out_channel)]
        seq += [activation]

        return nn.Sequential(*seq)

    seq = []
    seq += [CBA(20, 64*8, stride=1, padding=0)]
    seq += [CBA(64*8, 64*4)]
    seq += [CBA(64*4, 64*2)]
    seq += [CBA(64*2, 64)]
    seq += [CBA(64, 1, activation=nn.Tanh(), bn=False)]

    self.generator_network = nn.Sequential(*seq)

  def forward(self, z):
      out = self.generator_network(z)

      return out

 

为了帮助大家理解,我绘制 了生成网络的结构图,如下:

 
 

判别模型搭建

"""定义判别器网络结构"""
class Discriminator(nn.Module):

  def __init__(self):
    super(Discriminator, self).__init__()

    def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1, activation=nn.LeakyReLU(0.1, inplace=True)):
        seq = []
        seq += [nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding)]
        seq += [nn.BatchNorm2d(out_channel)]
        seq += [activation]

        return nn.Sequential(*seq)

    seq = []
    seq += [CBA(1, 64)]
    seq += [CBA(64, 64*2)]
    seq += [CBA(64*2, 64*4)]
    seq += [CBA(64*4, 64*8)]
    self.feature_network = nn.Sequential(*seq)

    self.critic_network = nn.Conv2d(64*8, 1, kernel_size=4, stride=1)

  def forward(self, x):
      out = self.feature_network(x)

      feature = out
      feature = feature.view(feature.size(0), -1)

      out = self.critic_network(out)

      return out, feature


 

  同样,为了方便大家理解,我也绘制了判别网络的结构图,如下:

  这里大家需要稍稍注意一下,判别网络有两个输出,一个是最终的输出,还有一个是第四个CBA BLOCK提取到的特征,这个在理论部分介绍损失函数时有提及。

 

模型训练

数据集加载

class image_data_set(Dataset):
    def __init__(self, data):
        self.images = data[:,:,:,None]
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(64, interpolation=InterpolationMode.BICUBIC),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.transform(self.images[idx])
        
 # 加载训练数据
 train_set = image_data_set(train)
 train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

 

​  这部分不难,但我提醒大家注意一下这句:transforms.Resize(64, interpolation=InterpolationMode.BICUBIC),即我们采用插值算法将原来28*28大小的图片上采样成了64*64大小。【感兴趣的这里也可以不对其进行上采样,这样的话大家需要修改一下上节的模型,可以试试效果喔】

 

加载模型、定义优化器、损失函数等参数

# 指定设备
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
# batch_size默认128
batch_size = args.batch_size
# 加载模型
G = Generator().to(device)
D = Discriminator().to(device)

# 训练模式
G.train()
D.train()

# 设置优化器
optimizerG = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.0, 0.9))
optimizerD = torch.optim.Adam(D.parameters(), lr=0.0004, betas=(0.0, 0.9))

# 定义损失函数
criterion = nn.BCEWithLogitsLoss(reduction='mean')

 

 

训练GAN网络

"""
训练
"""

# 开始训练
for epoch in range(args.epochs):
    # 定义初始损失
    log_g_loss, log_d_loss = 0.0, 0.0
    for images in train_loader:
        images = images.to(device)

        ## 训练判别器 Discriminator
        # 定义真标签(全1)和假标签(全0)   维度:(batch_size)
        label_real = torch.full((images.size(0),), 1.0).to(device)
        label_fake = torch.full((images.size(0),), 0.0).to(device)

        # 定义潜在变量z    维度:(batch_size,20,1,1)
        z = torch.randn(images.size(0), 20).to(device).view(images.size(0), 20, 1, 1).to(device)
        # 潜在变量喂入生成网络--->fake_images:(batch_size,1,64,61)
        fake_images = G(z)

        # 真图像和假图像送入判别网络,得到d_out_real、d_out_fake   维度:都为(batch_size,1,1,1)
        d_out_real, _ = D(images)
        d_out_fake, _ = D(fake_images)

        # 损失计算
        d_loss_real = criterion(d_out_real.view(-1), label_real)
        d_loss_fake = criterion(d_out_fake.view(-1), label_fake)
        d_loss = d_loss_real + d_loss_fake

        # 误差反向传播,更新损失
        optimizerD.zero_grad()
        d_loss.backward()
        optimizerD.step()

        ## 训练生成器 Generator
        # 定义潜在变量z    维度:(batch_size,20,1,1)
        z = torch.randn(images.size(0), 20).to(device).view(images.size(0), 20, 1, 1).to(device)
        fake_images = G(z)

        # 假图像喂入判别器,得到d_out_fake   维度:(batch_size,1,1,1)
        d_out_fake, _ = D(fake_images)

        # 损失计算
        g_loss = criterion(d_out_fake.view(-1), label_real)

        # 误差反向传播,更新损失
        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        ## 累计一个epoch的损失,判别器损失和生成器损失分别存放到log_d_loss、log_g_loss中
        log_d_loss += d_loss.item()
        log_g_loss += g_loss.item()

    ## 打印损失
    print(f'epoch {
     epoch}, D_Loss:{
     log_d_loss / 128:.4f}, G_Loss:{
     log_g_loss / 128:.4f}')




## 展示生成器存储的图片,存放在result文件夹下的G_out.jpg
z = torch.randn(8, 20).to(device).view(8, 20, 1, 1).to(device)
fake_images = G(z)
torchvision.utils.save_image(fake_images,f"result\G_out.jpg")

 

​  这部分就是训练一个DCGAN网络,到目前为止其实也都可以认为是DCGAN的内容。我们可以来看一下输出的G_out.jpg图片:

​  这里我们可以看到训练是有了效果的,但会发现不是特别好。我分析有两点原因,其一是我们的模型不好,且GAN本身就容易出现模式崩溃的问题;其二是我们的数据选取的少,在数据读取时训练集我们只取了前400个数据,但实际上我们一共可以取12116个,大家可以尝试增加数据,我想数据多了后效果肯定比这个好,大家快去试试吧!!!🍉🍉🍉

 

缺陷检测✨✨✨

​  这部分才是AnoGAN的重点,首先我们先定义损失的计算,如下:

## 定义缺陷计算的得分
def anomaly_score(input_image, fake_image, D):
# Residual loss 计算
residual_loss = torch.sum(torch.abs(input_image - fake_image), (1, 2, 3))

# Discrimination loss 计算
_, real_feature = D(input_image)
_, fake_feature = D(fake_image)
discrimination_loss = torch.sum(torch.abs(real_feature - fake_feature), (1))

# 结合Residual loss和Discrimination loss计算每张图像的损失
total_loss_by_image = 0.9 * residual_loss + 0.1 * discrimination_loss
# 计算总损失,即将一个batch的损失相加
total_loss = total_loss_by_image.sum()

return total_loss, total_loss_by_image, residual_loss

 

  大家可以对比一下理论部分损失函数的介绍,看看是不是一样的呢。

  接着我们就需要不断的搜索潜在变量z了,使其与输入图片尽可能接近,代码如下:

# 加载测试数据
test_set = image_data_set(test)
test_loader = DataLoader(test_set, batch_size=5, shuffle=False)
input_images = next(iter(test_loader)).to(device)

# 定义潜在变量z  维度:(5,20,1,1)
z = torch.randn(5, 20).to(device).view(5, 20, 1, 1)
# z的requires_grad参数设置成Ture,让z可以更新
z.requires_grad = True

# 定义优化器
z_optimizer = torch.optim.Adam([z], lr=1e-3)

# 搜索z
for epoch in range(5000):
    fake_images = G(z)
    loss, _, _ = anomaly_score(input_images, fake_images, D)

    z_optimizer.zero_grad()
    loss.backward()
    z_optimizer.step()

    if epoch % 1000 == 0:
    print(f'epoch: {
     epoch}, loss: {
     loss:.0f}')


 

​  执行完上述代码后,我们得到了一个较理想的潜在变量,这时候再用z来生成图片,并基于生成图片和输入图片来计算损失,同时,我们也保存了输入图片和生成图片,并打印了它们之前的损失,相关代码如下:

    fake_images = G(z)

    _, total_loss_by_image, _ = anomaly_score(input_images, fake_images, D)

    print(total_loss_by_image.cpu().detach().numpy())

    torchvision.utils.save_image(input_images, f"result/Nomal.jpg")
    torchvision.utils.save_image(fake_images, f"result/ANomal.jpg")

  我们可以来看看最后的结果哦,如下:

​  可以看到,当输入图像为2时(此为缺陷),生成的图像也是8,它们的损失最高为464040.44。这时候如果我们设置一个阈值为430000,高于这个阈值的即为异常图片,低于这个阈值的即为正常图片,那么我们是不是就可以通过AnoGAN来实现缺陷的检测了呢!!!🍒🍒🍒

 

总结

  到这里,AnoGAN的所有内容就介绍完了,大家好好感受感受它的思想,其实是很简单的,但是又非常巧妙。最后我不知道大家有没有发现AnoGAN一个非常明显的缺陷,那就是我们每次在判断异常时要不断的搜索潜在变量z,这是非常耗时的。而很多任务对时间的要求还是很高的,所以AnoGAN还有许多可以改进的地方,后续博文我会带大家继续学习GAN网络在缺陷检测中的应用,我们下期见。🖐🏽🖐🏽🖐🏽

 

参考文献

Unsupervised Anomaly Detection with Generative Adversarial Networks to Guide Marker Discovery 🍁🍁🍁

AnoGAN|GAN做图像异常检测的奠基石 🍁🍁🍁

GAN 使用 Pytorch 进行异常检测的方法 🍁🍁🍁

深度学习论文笔记(异常检测)—— Generative Adversarial Networks to Guide Marker Discovery 🍁🍁🍁

 
 
如若文章对你有所帮助,那就🛴🛴🛴


转载:https://blog.csdn.net/qq_47233366/article/details/127492230
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场