飞道的博客

我用AI给女友画了一幅画,这届算法有点强!

247人阅读  评论(0)

大家好,我是 Jack。

小时候,我其实还是有点艺术细胞的,喜欢看火影忍者和七龙珠的我,虽然没学过绘画,但也笨手笨脚地画了不少作品。

特意叫我妈,把我收藏多年的小破本拿出来,分享下我儿时的快乐。

小学几年级画的记不清了,只记得一画就是小半天,还拿去学校显摆了一番。

如今,再让我拿起铅笔,画个素描,我是画不出来了。

不过,我另辟蹊径,用起了算法。我lbw,没有开挂!

Anime2Sketch

Anime2Sketch 是一个动画、漫画、插画等艺术作品的素描提取器

给我个艺术作品,我直接把它变成素描作品:

 

耗时1秒临摹的素描作品:

 

Anime2Sketch 算法也非常简单,就是一个 UNet 结构,生成素描作品,可以看下它的网络结构:


  
  1. import torch
  2. import torch.nn as nn
  3. import functools
  4. class UnetGenerator(nn.Module):
  5. """Create a Unet-based generator"""
  6. def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
  7. """Construct a Unet generator
  8. Parameters:
  9. input_nc (int) -- the number of channels in input images
  10. output_nc (int) -- the number of channels in output images
  11. num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
  12. image of size 128x128 will become of size 1x1 # at the bottleneck
  13. ngf (int) -- the number of filters in the last conv layer
  14. norm_layer -- normalization layer
  15. We construct the U-Net from the innermost layer to the outermost layer.
  16. It is a recursive process.
  17. """
  18. super(UnetGenerator, self).__init__()
  19. # construct unet structure
  20. unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc= None, submodule= None, norm_layer=norm_layer, innermost= True) # add the innermost layer
  21. for _ in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
  22. unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc= None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
  23. # gradually reduce the number of filters from ngf * 8 to ngf
  24. unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc= None, submodule=unet_block, norm_layer=norm_layer)
  25. unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc= None, submodule=unet_block, norm_layer=norm_layer)
  26. unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc= None, submodule=unet_block, norm_layer=norm_layer)
  27. self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost= True, norm_layer=norm_layer) # add the outermost layer
  28. def forward(self, input):
  29. """Standard forward"""
  30. return self.model(input)
  31. class UnetSkipConnectionBlock(nn.Module):
  32. """Defines the Unet submodule with skip connection.
  33. X -------------------identity----------------------
  34. |-- downsampling -- |submodule| -- upsampling --|
  35. """
  36. def __init__(self, outer_nc, inner_nc, input_nc=None,
  37. submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
  38. """Construct a Unet submodule with skip connections.
  39. Parameters:
  40. outer_nc (int) -- the number of filters in the outer conv layer
  41. inner_nc (int) -- the number of filters in the inner conv layer
  42. input_nc (int) -- the number of channels in input images/features
  43. submodule (UnetSkipConnectionBlock) -- previously defined submodules
  44. outermost (bool) -- if this module is the outermost module
  45. innermost (bool) -- if this module is the innermost module
  46. norm_layer -- normalization layer
  47. use_dropout (bool) -- if use dropout layers.
  48. """
  49. super(UnetSkipConnectionBlock, self).__init__()
  50. self.outermost = outermost
  51. if type(norm_layer) == functools.partial:
  52. use_bias = norm_layer.func == nn.InstanceNorm2d
  53. else:
  54. use_bias = norm_layer == nn.InstanceNorm2d
  55. if input_nc is None:
  56. input_nc = outer_nc
  57. downconv = nn.Conv2d(input_nc, inner_nc, kernel_size= 4,
  58. stride= 2, padding= 1, bias=use_bias)
  59. downrelu = nn.LeakyReLU( 0.2, True)
  60. downnorm = norm_layer(inner_nc)
  61. uprelu = nn.ReLU( True)
  62. upnorm = norm_layer(outer_nc)
  63. if outermost:
  64. upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
  65. kernel_size= 4, stride= 2,
  66. padding= 1)
  67. down = [downconv]
  68. up = [uprelu, upconv, nn.Tanh()]
  69. model = down + [submodule] + up
  70. elif innermost:
  71. upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
  72. kernel_size= 4, stride= 2,
  73. padding= 1, bias=use_bias)
  74. down = [downrelu, downconv]
  75. up = [uprelu, upconv, upnorm]
  76. model = down + up
  77. else:
  78. upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
  79. kernel_size= 4, stride= 2,
  80. padding= 1, bias=use_bias)
  81. down = [downrelu, downconv, downnorm]
  82. up = [uprelu, upconv, upnorm]
  83. if use_dropout:
  84. model = down + [submodule] + up + [nn.Dropout( 0.5)]
  85. else:
  86. model = down + [submodule] + up
  87. self.model = nn.Sequential(*model)
  88. def forward(self, x):
  89. if self.outermost:
  90. return self.model(x)
  91. else: # add skip connections
  92. return torch.cat([x, self.model(x)], 1)
  93. def create_model(gpu_ids=[]):
  94. """Create a model for anime2sketch
  95. hardcoding the options for simplicity
  96. """
  97. norm_layer = functools.partial(nn.InstanceNorm2d, affine= False, track_running_stats= False)
  98. net = UnetGenerator( 3, 1, 8, 64, norm_layer=norm_layer, use_dropout= False)
  99. ckpt = torch.load( 'weights/netG.pth')
  100. for key in list(ckpt.keys()):
  101. if 'module.' in key:
  102. ckpt[key.replace( 'module.', '')] = ckpt[key]
  103. del ckpt[key]
  104. net.load_state_dict(ckpt)
  105. if len(gpu_ids) > 0:
  106. assert(torch.cuda.is_available())
  107. net.to(gpu_ids[ 0])
  108. net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
  109. return net

UNet 应该都很熟悉了,就不多介绍了。

项目地址:https://github.com/Mukosame/Anime2Sketch

环境部署也很简单,只需要安装以下三个库:


  
  1. torch>= 0. 4. 1
  2. torchvision>= 0. 2. 1
  3. Pillow>= 6. 0. 0

然后下载权重文件,即可。

权重文件放在了GoogleDrive,为了方便大家,我将代码和权重文件,还有一些测试图片,都打包好了。

直接下载,即可运行(提取码:a7r4):

https://pan.baidu.com/s/1h6bqgphqUUjj4fz61Y9HCA

进入项目根目录,直接运行命令:

python3 test.py --dataroot test_samples --load_size 512 --output_dir results

运行效果:

“画”得非常快,我在网上找了一些图片进行测试。

鸣人和带土:

柯南和灰原哀:

絮叨

使用算法前:

这样的素描,没有灵魂!

使用算法后:

拿了一些真人的图片进行了测试,发现效果很差,果然真人的线条还是要复杂一些的。

最后再送大家一本,帮助我拿到 BAT 等一线大厂 offer 的数据结构刷题笔记,是一位 Google 大神写的,对于算法薄弱或者需要提高的同学都十分受用(提起码:m19c):

BAT 大佬分类总结的 Leetcode 刷题模版,助你搞定 90% 的面试

以及我整理的 BAT 算法工程师学习路线,书籍+视频,完整的学习路线和说明,对于想成为算法工程师的,绝对能有所帮助(提取码:jack):

我是如何成为算法工程师的,超详细的学习路线

我是 Jack,我们下期见。


转载:https://blog.csdn.net/c406495762/article/details/116760610
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场