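小言_互联网's Blog

[Keras + Computer Vision + TensorFlow] DCGAN Generative Adversarial Network in Practice on the MNIST Handwritten Digit Dataset (Source Code and Dataset Included, Very Detailed)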


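If you need the source code and dataset, please like, follow, and bookmark this post, then leave a comment or send a private message.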

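I. The Concept of Generative Adversarial Networks

Generative adversarial networks (GANs, Generative Adversarial Nets), proposed by Ian Goodfellow in 2014, are among the most interesting ideas in computer science today. GANs were first proposed to compensate for a shortage of real data by generating high-quality synthetic data. The core idea is to train two models adversarially. As training proceeds, the generator network (Generator, G) gets better at creating realistic-looking images, while the discriminator network (Discriminator, D) gets better at telling real images apart from those produced by the generator. Rather than improving the performance of a single network, a GAN aims to reach a Nash equilibrium between the generator and the discriminator.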
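The adversarial training described above is usually written as a two-player minimax game. The formula below is the value function from the original 2014 GAN paper, added here as a reference. The symbols p_data (the distribution of real data) and p_z (the noise prior) do not appear in this post and are introduced only for this formula.

```latex
\min_{G}\,\max_{D}\; V(D, G)
  = \mathbb{E}_{x \sim p_{\text{data}}(x)}\bigl[\log D(x)\bigr]
  + \mathbb{E}_{z \sim p_{z}(z)}\bigl[\log\bigl(1 - D(G(z))\bigr)\bigr]
```

D is trained to make V large by scoring real samples near 1 and generated samples near 0, while G is trained to make V small by pushing D(G(z)) toward 1. In the idealized equilibrium the generated distribution matches the real one and D outputs 1/2 for every input.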

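Suppose a low-dimensional space Z carries a simple distribution p(z) that is easy to sample from, for example a normal distribution. The generator network defines a mapping G: Z→X, and the discriminator network must decide whether its input comes from the real data X_real or from the generator's output X_fake. The structure is shown in Figure 8-1.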
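The mapping G: Z→X and the discriminator's real-or-fake score can be sketched in a few lines of plain Python. This is a toy illustration, not the network used later in this post: the dimensions, the single linear layer, and the names `sample_z`, `make_linear`, `generator`, and `discriminator` are all assumptions chosen to keep the example small.

```python
import math
import random

LATENT_DIM = 4   # size of z; the full listing in this post uses 100
DATA_DIM = 6     # size of x; a flattened MNIST image would be 28 * 28 = 784


def sample_z(rng, dim=LATENT_DIM):
    """Draw one latent vector z from a standard normal p(z)."""
    return [rng.gauss(0.0, 1.0) for _ in range(dim)]


def make_linear(rng, n_in, n_out):
    """Random weight matrix (n_out rows, n_in columns) for a toy layer."""
    return [[rng.gauss(0.0, 0.5) for _ in range(n_in)] for _ in range(n_out)]


def generator(z, weights):
    """Toy G: Z -> X, one linear layer followed by tanh."""
    return [math.tanh(sum(w * v for w, v in zip(row, z))) for row in weights]


def discriminator(x, weights):
    """Toy D: X -> (0, 1), the estimated probability that x is real."""
    logit = sum(w * v for w, v in zip(weights, x))
    return 1.0 / (1.0 + math.exp(-logit))


rng = random.Random(0)  # fixed seed so repeated runs give the same numbers
g_weights = make_linear(rng, LATENT_DIM, DATA_DIM)
d_weights = [rng.gauss(0.0, 0.5) for _ in range(DATA_DIM)]

z = sample_z(rng)                         # a point in Z
x_fake = generator(z, g_weights)          # its image X_fake in X
score = discriminator(x_fake, d_weights)  # D's belief that x_fake is real

print("x_fake =", [round(v, 3) for v in x_fake])
print("D(x_fake) = %.3f" % score)
```

Because of the tanh, every component of `x_fake` lies in [-1, 1], which is the same range the real images are scaled to before training.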

 

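Below are examples of DCGAN generating bedroom samples from the LSUN dataset, along with generated face samples. Although DCGAN still struggles to produce high-fidelity image samples, results like these were already enough to amaze people.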

 

 

II. DCGAN in Practice on the MNIST Handwritten Digit Dataset

This program trains two models: a generator model and a discriminator model.
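How the two models pull against each other is easiest to see in the loss. The sketch below computes binary cross-entropy by hand for one small batch, using the same labelling as the listing in Section III: real images are labelled 1 and generated images 0 when the discriminator is updated, and generated images are labelled 1 when the generator is updated. The discriminator outputs in `d_on_real` and `d_on_fake` are made-up numbers for illustration.

```python
import math


def binary_crossentropy(labels, probs, eps=1e-7):
    """Mean binary cross-entropy between 0/1 labels and predicted probabilities."""
    total = 0.0
    for y, p in zip(labels, probs):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(labels)


# Hypothetical discriminator outputs for 2 real and 2 generated images.
d_on_real = [0.9, 0.8]
d_on_fake = [0.3, 0.1]

# Discriminator step: real images are labelled 1, generated images 0.
d_loss = binary_crossentropy([1, 1, 0, 0], d_on_real + d_on_fake)

# Generator step: the same generated images, now labelled 1, so the loss is
# small only when the discriminator is fooled.
g_loss = binary_crossentropy([1, 1], d_on_fake)

print("d_loss = %.4f, g_loss = %.4f" % (d_loss, g_loss))
```

With these numbers the discriminator is doing well, so its loss is low and the generator's loss is high. As the generator improves, `d_on_fake` rises and the two losses move in opposite directions.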

1: The project structure is as follows

 

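The code can be roughly divided into the following parts

1: Build the generator network

2: Build the discriminator network

3: Train the DCGAN network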
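To see how the generator and discriminator in Section III fit together, it helps to trace the tensor shapes layer by layer. The sketch below does this with plain arithmetic instead of Keras. The helper names `conv2d`, `upsample`, and `max_pool` are hypothetical and only mimic the shape rules of the corresponding Keras layers: a stride-1 convolution with 'same' padding keeps height and width, one with the default 'valid' padding shrinks each by kernel size minus one, 2x2 upsampling doubles them, and 2x2 max pooling halves them.

```python
def conv2d(shape, filters, kernel=5, padding="valid"):
    """Output shape of a stride-1 convolution on an (h, w, c) input."""
    h, w, _ = shape
    if padding == "valid":
        h, w = h - kernel + 1, w - kernel + 1
    return (h, w, filters)


def upsample(shape, factor=2):
    """Output shape of nearest-neighbour upsampling."""
    h, w, c = shape
    return (h * factor, w * factor, c)


def max_pool(shape, pool=2):
    """Output shape of max pooling with stride equal to the pool size."""
    h, w, c = shape
    return (h // pool, w // pool, c)


# Generator: 100-d noise -> Dense(1024) -> Dense(128*7*7) -> Reshape -> ...
g = (7, 7, 128)                       # after Reshape
g = upsample(g)                       # (14, 14, 128)
g = conv2d(g, 64, padding="same")     # (14, 14, 64)
g = upsample(g)                       # (28, 28, 64)
g = conv2d(g, 1, padding="same")      # (28, 28, 1)

# Discriminator: 28x28x1 image -> ... -> Dense(1024) -> Dense(1)
d = (28, 28, 1)
d = conv2d(d, 64, padding="same")     # (28, 28, 64)
d = max_pool(d)                       # (14, 14, 64)
d = conv2d(d, 128)                    # (10, 10, 128), default 'valid' padding
d = max_pool(d)                       # (5, 5, 128)
flat = d[0] * d[1] * d[2]             # 3200 inputs to Dense(1024)

print("generator output:", g)
print("discriminator features before Flatten:", d, "->", flat)
```

The generator's output shape matches the discriminator's input shape, which is what allows the two networks to be stacked into a single model for the generator update.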

Start downloading the model

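2: Results

The generated images are shown below, and the results look quite realistic.

This is the first generated image. Some of the digits still look slightly unrealistic and would be easy for the discriminator to identify as fake.

This image is generated very convincingly, with almost no flaws.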

III. Code

Part of the code is shown below.


  
# All Keras imports go through tensorflow.keras. The original listing mixed
# keras, keras.layers.core and tensorflow.python.keras.
import argparse
import math

import numpy as np
from PIL import Image
from tensorflow.keras import Input
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
                                     Dense, Flatten, MaxPooling2D, Reshape,
                                     UpSampling2D)
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import SGD

# Weight file names. The original used 'generator' and 'discriminator' with no
# suffix; the '.weights.h5' suffix is required by Keras 3 and also works with
# TensorFlow 2.x.
GENERATOR_WEIGHTS = 'generator.weights.h5'
DISCRIMINATOR_WEIGHTS = 'discriminator.weights.h5'


def generator_model():
    # Maps a 100-dimensional noise vector to a 28x28x1 image in [-1, 1].
    model = Sequential()
    model.add(Input(shape=(100,)))
    model.add(Dense(1024))
    model.add(Activation('tanh'))
    model.add(Dense(128 * 7 * 7))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    model.add(Reshape((7, 7, 128)))
    model.add(UpSampling2D(size=(2, 2)))   # 7x7 -> 14x14
    model.add(Conv2D(64, (5, 5), padding='same'))
    model.add(Activation('tanh'))
    model.add(UpSampling2D(size=(2, 2)))   # 14x14 -> 28x28
    model.add(Conv2D(1, (5, 5), padding='same'))
    model.add(Activation('tanh'))
    return model


def discriminator_model():
    # Maps a 28x28x1 image to the probability that it is a real MNIST digit.
    model = Sequential()
    model.add(Input(shape=(28, 28, 1)))
    model.add(Conv2D(64, (5, 5), padding='same'))
    model.add(Activation('tanh'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(128, (5, 5)))
    model.add(Activation('tanh'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(1024))
    model.add(Activation('tanh'))
    model.add(Dense(1))
    model.add(Activation('sigmoid'))
    return model


def generator_containing_discriminator(g, d):
    # Stacks G and D so that G can be trained through a frozen D.
    model = Sequential()
    model.add(g)
    d.trainable = False
    model.add(d)
    return model


def combine_images(generated_images):
    # Tiles a batch of images into one 2-D grid image.
    num = generated_images.shape[0]
    width = int(math.sqrt(num))
    height = int(math.ceil(float(num) / width))
    shape = generated_images.shape[1:3]
    image = np.zeros((height * shape[0], width * shape[1]),
                     dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        i = index // width
        j = index % width
        image[i*shape[0]:(i+1)*shape[0], j*shape[1]:(j+1)*shape[1]] = \
            img[:, :, 0]
    return image


def train(BATCH_SIZE, path):
    # `path` is passed in from the command line but is not used in this excerpt.
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    # Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output.
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = X_train[:, :, :, None]
    X_test = X_test[:, :, :, None]
    d = discriminator_model()
    g = generator_model()
    d_on_g = generator_containing_discriminator(g, d)
    d_optim = SGD(learning_rate=0.0005, momentum=0.9, nesterov=True)
    g_optim = SGD(learning_rate=0.0005, momentum=0.9, nesterov=True)
    g.compile(loss='binary_crossentropy', optimizer="SGD")
    d_on_g.compile(loss='binary_crossentropy', optimizer=g_optim)
    d.trainable = True
    d.compile(loss='binary_crossentropy', optimizer=d_optim)
    num_batches = X_train.shape[0] // BATCH_SIZE
    # Labels as (N, 1) arrays to match the discriminator's output shape.
    real_labels = np.ones((BATCH_SIZE, 1), dtype=np.float32)
    fake_labels = np.zeros((BATCH_SIZE, 1), dtype=np.float32)
    for epoch in range(100):
        print("Epoch is", epoch)
        print("Number of batches", num_batches)
        for index in range(num_batches):
            noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))
            image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]
            generated_images = g.predict(noise, verbose=0)
            if index % 20 == 0:
                # Save a grid of samples, mapped back to [0, 255].
                image = combine_images(generated_images)
                image = image * 127.5 + 127.5
                Image.fromarray(image.astype(np.uint8)).save(
                    str(epoch) + "_" + str(index) + ".png")
            # Discriminator step: real images labelled 1, generated images 0.
            X = np.concatenate((image_batch, generated_images))
            y = np.concatenate((real_labels, fake_labels))
            d_loss = d.train_on_batch(X, y)
            print("batch %d d_loss : %f" % (index, d_loss))
            # Generator step: train G through D with the target label 1.
            noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
            d.trainable = False
            g_loss = d_on_g.train_on_batch(noise, real_labels)
            d.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss))
            if index % 10 == 9:
                g.save_weights(GENERATOR_WEIGHTS, overwrite=True)
                d.save_weights(DISCRIMINATOR_WEIGHTS, overwrite=True)


def generate(BATCH_SIZE, nice=False):
    g = generator_model()
    g.compile(loss='binary_crossentropy', optimizer="SGD")
    g.load_weights(GENERATOR_WEIGHTS)
    if nice:
        # The original listing omits the lines that define `d`, `noise` and
        # `generated_images`. The next five lines are a reconstruction inferred
        # from how those names are used below.
        d = discriminator_model()
        d.compile(loss='binary_crossentropy', optimizer="SGD")
        d.load_weights(DISCRIMINATOR_WEIGHTS)
        noise = np.random.uniform(-1, 1, (BATCH_SIZE * 20, 100))
        generated_images = g.predict(noise, verbose=1)
        # Score every sample with D and keep the BATCH_SIZE highest-scoring ones.
        d_pret = d.predict(generated_images, verbose=1)
        index = np.arange(0, BATCH_SIZE * 20)
        index.resize((BATCH_SIZE * 20, 1))
        pre_with_index = list(np.append(d_pret, index, axis=1))
        pre_with_index.sort(key=lambda x: x[0], reverse=True)
        nice_images = np.zeros((BATCH_SIZE,) + generated_images.shape[1:3],
                               dtype=np.float32)
        nice_images = nice_images[:, :, :, None]
        for i in range(BATCH_SIZE):
            idx = int(pre_with_index[i][1])
            nice_images[i, :, :, 0] = generated_images[idx, :, :, 0]
        image = combine_images(nice_images)
    else:
        # This branch is garbled in the original; reconstructed from the
        # surviving fragment "predict(noise, verbose=1)".
        noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
        generated_images = g.predict(noise, verbose=1)
        image = combine_images(generated_images)
    image = image * 127.5 + 127.5
    Image.fromarray(image.astype(np.uint8)).save(
        "generated_image.png")


def get_args():
    parser = argparse.ArgumentParser()
    # Use --mode generate to sample from a trained generator.
    parser.add_argument("--mode", type=str, default='train')
    parser.add_argument("--batch_size", type=int, default=8)
    # The original listing is cut off here. The two options below are
    # assumptions, inferred from the uses of args.path and args.nice in the
    # main block; their types and defaults are not given in the source.
    parser.add_argument("--path", type=str, default=None)
    parser.add_argument("--nice", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    if args.mode == "train":
        train(BATCH_SIZE=args.batch_size, path=args.path)
    elif args.mode == "generate":
        generate(BATCH_SIZE=args.batch_size, nice=args.nice)
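Two small pieces of arithmetic in the listing above are easy to get wrong: the pixel scaling between [0, 255] and the tanh range [-1, 1], and the grid layout used by `combine_images`. The sketch below reproduces both in plain Python so they can be checked without Keras. The helper names `to_model_range`, `to_pixel`, `grid_shape`, and `grid_position` are hypothetical and do not appear in the listing.

```python
import math


def to_model_range(pixel):
    """Map an 8-bit pixel value (0..255) to the tanh range [-1, 1]."""
    return (pixel - 127.5) / 127.5


def to_pixel(value):
    """Map a value in [-1, 1] back to 0..255.

    round() guards against floating-point error here; the listing instead
    truncates with astype(np.uint8).
    """
    return int(round(value * 127.5 + 127.5))


def grid_shape(num_images):
    """Rows and columns that combine_images uses for num_images tiles."""
    width = int(math.sqrt(num_images))
    height = int(math.ceil(num_images / width))
    return height, width


def grid_position(index, width):
    """Row and column of the tile with the given index."""
    return index // width, index % width


print(to_model_range(0), to_model_range(255))
print(to_pixel(to_model_range(64)))
print(grid_shape(8), grid_position(5, grid_shape(8)[1]))
```

With the listing's default batch size of 8, the saved sample image is a grid of 4 rows by 2 columns, so each PNG is 112 pixels high and 56 pixels wide.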



Reposted from: https://blog.csdn.net/jiebaoshayebuhui/article/details/128262166