飞道的博客

五、肺癌检测-数据集训练 training.py model.py

10471人阅读  评论(0)

上一篇文章中已经通过将dsets.py实现将数据集封装加载,之后就可以通过建立了模型并编写training脚本实现模型的训练了。这一篇文章主要是对《pytorch深度学习实战》第11章内容做的笔记。

一、目标

1、建立简单的卷积神经网络

2、编写训练函数

3、编写训练日志(训练和验证过程的loss,accuracy等)数据结构

4、使用tensorboard可视化训练信息。

二、要点说明

1. 对函数使用通用的系统进程级别的调用

原书代码的【code/p2_run_everything.ipynb】的cell2中,定义了一个通用的系统进程方式的调用方法。通过这种方法可以调用所有脚本中的函数。但个人认为还是挺麻烦的,一点都不人性化。建议不要把精力花在这部分代码上,知道代码是在干嘛就行。


  
  1. def run( app, *argv):
  2. argv = list(argv)
  3. argv.insert( 0, '--num-workers=4') # <1> 使用4个核
  4. log.info( "Running: {}({!r}).main()". format(app, argv))
  5. app_cls = importstr(*app.rsplit( '.', 1)) # <2> # 动态加载库
  6. app_cls(argv).main() # 调用app类的main函数
  7. log.info( "Finished: {}.{!r}).main()". format(app, argv))

使用示例:从p2ch11文件夹的training.py文件中importLunaTrainingApp类并调用其main函数,函数的输入参数是epochs=1。

run('p2ch11.training.LunaTrainingApp', '--epochs=1')

其中:

1.1 importstr函数

函数是为了实现动态调用各个库和库函数。类似于from 【pkg_name】 import 【func_name】的作用。通过importstr可以实现动态加载函数,而不用调用前用import声明。

1.2 rsplit函数

 函数用法:list = str.rsplit(sep, maxsplit)。可参考下面的文章。简单而言就是对字符【str】按照【sep】分隔符进行拆分,从字符右侧开始拆分,一共拆分【maxsplit】次。返回的是拆分结果是一个list。

Python实用语法之rsplit_明 总 有的博客-CSDN博客_python rsplit

1.3 argparse库

在原书代码的【prepcache.py】文件中,使用了argparse库。argparse库是用来解决使用命令行执行函数时,让命令行能够解析我们输入的参数名称和参数值的问题。定义了参数解释器后,我们在命令行执行函数时,就可以像使用conda命令一样,用类似【conda --user xxx】一样的方式来执行函数了。

argparse库的具体用法可以参考以下文章:

argparse.ArgumentParser()的用法_无尽的沉默的博客-CSDN博客_argparse.argumentparser

简单用法如下:


  
  1. import argparse
  2. parser = argparse.ArgumentParser() # 创建一个参数解释器
  3. parser.add_argument( "--arg1", type= int, help= "一个整数", default= 1) # 通过 --argName方式声明参数,为int类型
  4. parser.add_argument( "--arg2", type= int, help= "一个整数", default= 2) # 通过 --argName方式声明参数,为int类型
  5. args = parser.parse_args() # 解析参数
  6. print( "arg1 = {0}". format(args.arg1))
  7. print( "arg2 = {0}". format(args.arg2))

 使用命令行运行结果如下:


  
  1. (pytorch) E:\CT\code>python test2.py --arg1 1 --arg2 2
  2. arg1 = 1
  3. arg2 = 2

1.4  @classmethod修饰器

在原书代码的【prepcache.py】文件中,使用了@classmethod修饰器,这样就可以不实例化对象直接调用类内的函数。

2. 模型建立

书中在11章用的是简单的卷积堆叠+线性层的神经网络结果,没任何特别之处。其中线性层由于只是简单2分类(结节是否为肿瘤),所以只用了一个线性层。卷积和池化用的是3维的卷积和池化。

2.1 多GPU设置

多GPU训练可通过nn.DataParallel(model)或DistributedParallel函数实现,前者较为简单,一般用在单机多卡场景,后者配置较为复杂,一般用在多台计算机的多卡场景。

2.2 优化器

一般开始训练时可以先尝试使用带动量的SGD,lr=0.001,momentum=0.9,不行再换其他优化器,如Adam。

2.3 模型输入尺寸

在上一篇文章中的ct类介绍中,width_irc参数定义了每个在irc坐标系的尺寸大小。也是数据集输入到模型的input_size。

2.4 模型信息

使用torchinfo库或者torchsummary库的summary函数都可以打印模型的参数信息。具体方法如下:


  
  1. from p2ch11.model import LunaModel
  2. import torchinfo # 安装命令conda install torchinfo
  3. model = LunaModel()
  4. torchinfo.summary(model, ( 1, 32, 48, 48), batch_dim= 0,
  5. col_names = ( 'input_size', 'output_size', 'num_params', 'kernel_size', 'mult_adds'), verbose = 1)

运行结果,即模型信息如下:


  
  1. =====================================================================================================================================================================
  2. Layer ( type:depth-idx) Input Shape Output Shape Param # Kernel Shape Mult-Adds
  3. =====================================================================================================================================================================
  4. LunaModel [ 1, 1, 32, 48, 48] [ 1, 2] -- -- --
  5. ├─BatchNorm3d: 1- 1 [ 1, 1, 32, 48, 48] [ 1, 1, 32, 48, 48] 2 -- 2
  6. ├─LunaBlock: 1- 2 [ 1, 1, 32, 48, 48] [ 1, 8, 16, 24, 24] -- -- --
  7. │ └─Conv3d: 2- 1 [ 1, 1, 32, 48, 48] [ 1, 8, 32, 48, 48] 224 [ 3, 3, 3] 16, 515,072
  8. │ └─ReLU: 2- 2 [ 1, 8, 32, 48, 48] [ 1, 8, 32, 48, 48] -- -- --
  9. │ └─Conv3d: 2- 3 [ 1, 8, 32, 48, 48] [ 1, 8, 32, 48, 48] 1, 736 [ 3, 3, 3] 127, 991, 808
  10. │ └─ReLU: 2- 4 [ 1, 8, 32, 48, 48] [ 1, 8, 32, 48, 48] -- -- --
  11. │ └─MaxPool3d: 2- 5 [ 1, 8, 32, 48, 48] [ 1, 8, 16, 24, 24] -- 2 --
  12. ├─LunaBlock: 1- 3 [ 1, 8, 16, 24, 24] [ 1, 16, 8, 12, 12] -- -- --
  13. │ └─Conv3d: 2- 6 [ 1, 8, 16, 24, 24] [ 1, 16, 16, 24, 24] 3, 472 [ 3, 3, 3] 31, 997, 952
  14. │ └─ReLU: 2- 7 [ 1, 16, 16, 24, 24] [ 1, 16, 16, 24, 24] -- -- --
  15. │ └─Conv3d: 2- 8 [ 1, 16, 16, 24, 24] [ 1, 16, 16, 24, 24] 6, 928 [ 3, 3, 3] 63, 848, 448
  16. │ └─ReLU: 2- 9 [ 1, 16, 16, 24, 24] [ 1, 16, 16, 24, 24] -- -- --
  17. │ └─MaxPool3d: 2- 10 [ 1, 16, 16, 24, 24] [ 1, 16, 8, 12, 12] -- 2 --
  18. ├─LunaBlock: 1- 4 [ 1, 16, 8, 12, 12] [ 1, 32, 4, 6, 6] -- -- --
  19. │ └─Conv3d: 2- 11 [ 1, 16, 8, 12, 12] [ 1, 32, 8, 12, 12] 13, 856 [ 3, 3, 3] 15, 962, 112
  20. │ └─ReLU: 2- 12 [ 1, 32, 8, 12, 12] [ 1, 32, 8, 12, 12] -- -- --
  21. │ └─Conv3d: 2- 13 [ 1, 32, 8, 12, 12] [ 1, 32, 8, 12, 12] 27, 680 [ 3, 3, 3] 31, 887, 360
  22. │ └─ReLU: 2- 14 [ 1, 32, 8, 12, 12] [ 1, 32, 8, 12, 12] -- -- --
  23. │ └─MaxPool3d: 2- 15 [ 1, 32, 8, 12, 12] [ 1, 32, 4, 6, 6] -- 2 --
  24. ├─LunaBlock: 1- 5 [ 1, 32, 4, 6, 6] [ 1, 64, 2, 3, 3] -- -- --
  25. │ └─Conv3d: 2- 16 [ 1, 32, 4, 6, 6] [ 1, 64, 4, 6, 6] 55, 360 [ 3, 3, 3] 7, 971, 840
  26. │ └─ReLU: 2- 17 [ 1, 64, 4, 6, 6] [ 1, 64, 4, 6, 6] -- -- --
  27. │ └─Conv3d: 2- 18 [ 1, 64, 4, 6, 6] [ 1, 64, 4, 6, 6] 110, 656 [ 3, 3, 3] 15, 934, 464
  28. │ └─ReLU: 2- 19 [ 1, 64, 4, 6, 6] [ 1, 64, 4, 6, 6] -- -- --
  29. │ └─MaxPool3d: 2- 20 [ 1, 64, 4, 6, 6] [ 1, 64, 2, 3, 3] -- 2 --
  30. ├─Linear: 1- 6 [ 1, 1152] [ 1, 2] 2, 306 -- 2, 306
  31. ├─Softmax: 1- 7 [ 1, 2] [ 1, 2] -- -- --
  32. =====================================================================================================================================================================
  33. Total params: 222, 220
  34. Trainable params: 222, 220
  35. Non-trainable params: 0
  36. Total mult-adds (M): 312.11
  37. =====================================================================================================================================================================
  38. Input size (MB): 0.29
  39. Forward/backward pass size (MB): 13.12
  40. Params size (MB): 0.89
  41. Estimated Total Size (MB): 14.31
  42. =====================================================================================================================================================================
  43. Process finished with exit code 0

3. 初始化

训练开始前,需要对权重进行初始化,初始化方法是通用的,具体参照书中代码【model.py】的_init_weights函数。


  
  1. def _init_weights( self):
  2. for m in self.modules():
  3. if type(m) in {
  4. nn.Linear,
  5. nn.Conv3d,
  6. nn.Conv2d,
  7. nn.ConvTranspose2d,
  8. nn.ConvTranspose3d,
  9. }:
  10. nn.init.kaiming_normal_(
  11. m.weight.data, a= 0, mode= 'fan_out', nonlinearity= 'relu',
  12. )
  13. if m.bias is not None:
  14. fan_in, fan_out = \
  15. nn.init._calculate_fan_in_and_fan_out(m.weight.data)
  16. bound = 1 / math.sqrt(fan_out)
  17. nn.init.normal_(m.bias, -bound, bound)

4. 代码运行时间预计

原书代码中,定义了enumerateWithEstimate函数来预计运行完某段代码所需的运行时间。其中关键是利用了yield关键字,使enumerateWithEstimate一次次的迭代加载数据集。关于yield的用法可参考下面的文章。

python中yield的用法详解——最简单,最清晰的解释_冯爽朗的博客-CSDN博客_python yield

总的来说,声明为yield关键子的函数func,调用时类似断点执行:

1.首次执行时,代码执行到yield关键字右侧部分代码,并返回右侧部分代码的结果,类似return。yield之后的代码不在执行。

2. 用next函数再次调用函数func时,函数func继续从yield之后的代码开始执行,直到碰到下一个yield;如果函数后续没有别的yield关键字,则函数运行到末尾后返回函数开头重新运行,直至碰到yield。

3. 每次用next函数调用func时,不断重复第2点的执行方式。

5. 提高数据加载速度

原书中,作者通过diskacache库,将第一次加载的数据集缓存到磁盘中,下次训练或者验证再加载数据的时候,可直接在磁盘缓存中加载,可节省极大部分数据加载和预处理的时间。具体diskache库用法可参考下面的文章:

https://blog.csdn.net/wxyczhyza/article/details/127773721

三、代码

原书代码可根据下面文章的代码链接下载,这里贴下我自己注释过的代码吧:

1. 网络模型 model.py

代码如下:


  
  1. import math
  2. from torch import nn as nn
  3. from util.logconf import logging
  4. log = logging.getLogger(__name__)
  5. # log.setLevel(logging.WARN)
  6. # log.setLevel(logging.INFO)
  7. log.setLevel(logging. DEBUG)
  8. class LunaModel(nn.Module):
  9. def __init__(self, in_channels=1, conv_channels=8):
  10. super().__init__()
  11. self.tail_batchnorm = nn.BatchNorm3d(1)
  12. self.block1 = LunaBlock(in_channels, conv_channels)
  13. self.block2 = LunaBlock(conv_channels, conv_channels * 2)
  14. self.block3 = LunaBlock(conv_channels * 2, conv_channels * 4)
  15. self.block4 = LunaBlock(conv_channels * 4, conv_channels * 8)
  16. self.head_linear = nn.Linear(1152, 2)
  17. self.head_softmax = nn.Softmax(dim=1)
  18. self._init_weights()
  19. # see also https://github.com/pytorch/pytorch/issues/18182
  20. def _init_weights(self):
  21. for m in self.modules():
  22. if type(m) in {
  23. nn.Linear,
  24. nn.Conv3d,
  25. nn.Conv2d,
  26. nn.ConvTranspose2d,
  27. nn.ConvTranspose3d,
  28. }:
  29. nn.init.kaiming_normal_(
  30. m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
  31. )
  32. if m.bias is not None:
  33. fan_in, fan_out = \
  34. nn.init._calculate_fan_in_and_fan_out(m.weight.data)
  35. bound = 1 / math.sqrt(fan_out)
  36. nn.init.normal_(m.bias, -bound, bound)
  37. def forward(self, input_batch):
  38. bn_output = self.tail_batchnorm(input_batch)
  39. block_out = self.block1(bn_output)
  40. block_out = self.block2(block_out)
  41. block_out = self.block3(block_out)
  42. block_out = self.block4(block_out)
  43. conv_flat = block_out.view(
  44. block_out.size(0),
  45. -1,
  46. )
  47. linear_output = self.head_linear(conv_flat)
  48. return linear_output, self.head_softmax(linear_output)
  49. class LunaBlock(nn.Module):
  50. def __init__(self, in_channels, conv_channels):
  51. super().__init__()
  52. self.conv1 = nn.Conv3d(
  53. in_channels, conv_channels, kernel_size=3, padding=1, bias=True,
  54. )
  55. self.relu1 = nn.ReLU(inplace=True)
  56. self.conv2 = nn.Conv3d(
  57. conv_channels, conv_channels, kernel_size=3, padding=1, bias=True,
  58. )
  59. self.relu2 = nn.ReLU(inplace=True)
  60. self.maxpool = nn.MaxPool3d(2, 2)
  61. def forward(self, input_batch):
  62. block_out = self.conv1(input_batch)
  63. block_out = self.relu1(block_out)
  64. block_out = self.conv2(block_out)
  65. block_out = self.relu2(block_out)
  66. return self.maxpool(block_out)

2.  enumerateWithEstimate函数

函数位置:util\util.py

函数主要用了yield关键字,使enumerateWithEstimate函数变为一个迭代器生成器,不断的迭代加载数据集,并根据每次迭代的时间来预估加载完整个数据集所需要的总时间。


  
  1. # 函数实现预估加载完整个迭代器所需要的时间。具体原理:
  2. # step1:使用yield关键字,每次加载一部分数据集,统计这部分数据集的平均单个数据集的使用时间delta_t = 花费的时间/该部分数据集样本数
  3. # step2:根据迭代器长度,预估加载整个数据集所花时间 t_dataset = delta_t * 数据集长度
  4. def enumerateWithEstimate(
  5. iter, # 数据集的一个迭代器。函数目的就是统计加载完整个数据集所需要的时间。
  6. desc_str, # 打印log的时候的说明文本。自己随便定义就行。
  7. start_ndx= 0, # 开始统计前跳过的统计此时。比如start_ndx=3,则意思是第1,2次统计不打印,第三次开始打印。
  8. print_ndx= 4, # 相邻两次打印日志的统计次数间隔print_ndx = print_ndx * backoff,缺省的初始值为4
  9. backoff= None, # 相邻两次打印日志的统计次数间隔的倍数。print_ndx = print_ndx * backoff
  10. iter_len= None, # 迭代器的长度,不指定时,iter_len = len(iter)
  11. ):
  12. """
  13. In terms of behavior, `enumerateWithEstimate` is almost identical
  14. to the standard `enumerate` (the differences are things like how
  15. our function returns a generator, while `enumerate` returns a
  16. specialized `<enumerate object at 0x...>`).
  17. However, the side effects (logging, specifically) are what make the
  18. function interesting.
  19. :param iter: `iter` is the iterable that will be passed into
  20. `enumerate`. Required.
  21. :param desc_str: This is a human-readable string that describes
  22. what the loop is doing. The value is arbitrary, but should be
  23. kept reasonably short. Things like `"epoch 4 training"` or
  24. `"deleting temp files"` or similar would all make sense.
  25. :param start_ndx: This parameter defines how many iterations of the
  26. loop should be skipped before timing actually starts. Skipping
  27. a few iterations can be useful if there are startup costs like
  28. caching that are only paid early on, resulting in a skewed
  29. average when those early iterations dominate the average time
  30. per iteration.
  31. NOTE: Using `start_ndx` to skip some iterations makes the time
  32. spent performing those iterations not be included in the
  33. displayed duration. Please account for this if you use the
  34. displayed duration for anything formal.
  35. This parameter defaults to `0`.
  36. :param print_ndx: determines which loop interation that the timing
  37. logging will start on. The intent is that we don't start
  38. logging until we've given the loop a few iterations to let the
  39. average time-per-iteration a chance to stablize a bit. We
  40. require that `print_ndx` not be less than `start_ndx` times
  41. `backoff`, since `start_ndx` greater than `0` implies that the
  42. early N iterations are unstable from a timing perspective.
  43. `print_ndx` defaults to `4`.
  44. :param backoff: This is used to how many iterations to skip before
  45. logging again. Frequent logging is less interesting later on,
  46. so by default we double the gap between logging messages each
  47. time after the first.
  48. `backoff` defaults to `2` unless iter_len is > 1000, in which
  49. case it defaults to `4`.
  50. :param iter_len: Since we need to know the number of items to
  51. estimate when the loop will finish, that can be provided by
  52. passing in a value for `iter_len`. If a value isn't provided,
  53. then it will be set by using the value of `len(iter)`.
  54. :return:
  55. """
  56. if iter_len is None:
  57. iter_len = len( iter)
  58. if backoff is None:
  59. backoff = 2
  60. while backoff ** 7 < iter_len:
  61. backoff *= 2
  62. assert backoff >= 2
  63. while print_ndx < start_ndx * backoff:
  64. print_ndx *= backoff
  65. log.warning( "{} ----/{}, starting". format(
  66. desc_str,
  67. iter_len,
  68. ))
  69. start_ts = time.time()
  70. for (current_ndx, item) in enumerate( iter):
  71. yield (current_ndx, item)
  72. if current_ndx == print_ndx:
  73. # ... <1> step1:计算若干隔数据集加载时间;step2:平均得到每个数据集加载时间;step3:乘以数据集长度得到预计加载所有数据的时间
  74. duration_sec = ((time.time() - start_ts)
  75. / (current_ndx - start_ndx + 1)
  76. * (iter_len-start_ndx)
  77. )
  78. done_dt = datetime.datetime.fromtimestamp(start_ts + duration_sec)
  79. done_td = datetime.timedelta(seconds=duration_sec)
  80. log.info( "{} {:-4}/{}, done at {}, {}". format(
  81. desc_str,
  82. current_ndx,
  83. iter_len,
  84. str(done_dt).rsplit( '.', 1)[ 0], # 运行了current_ndx次后,预估的加载完整个数据集后的系统时间
  85. str(done_td).rsplit( '.', 1)[ 0], # 运行了current_ndx次后,预估的加载完整个数据集所需要的秒数
  86. ))
  87. print_ndx *= backoff
  88. if current_ndx + 1 == start_ndx:
  89. start_ts = time.time()
  90. log.warning( "{} ----/{}, done at {}". format(
  91. desc_str,
  92. iter_len,
  93. str(datetime.datetime.now()).rsplit( '.', 1)[ 0],
  94. ))

3. prepcahe.py

这个脚本用来尝试加载整个数据集,测试加载数据集所需要的时间。核心时调用enumerateWithEstimate函数。


  
  1. import argparse # 参数解释器
  2. import sys
  3. import numpy as np
  4. import torch.nn as nn
  5. from torch.autograd import Variable
  6. from torch.optim import SGD
  7. from torch.utils.data import DataLoader
  8. from util.util import enumerateWithEstimate
  9. from .dsets import LunaDataset
  10. from util.logconf import logging
  11. from .model import LunaModel
  12. log = logging.getLogger(__name__)
  13. # log.setLevel(logging.WARN)
  14. log.setLevel(logging.INFO)
  15. # log.setLevel(logging.DEBUG)
  16. class LunaPrepCacheApp:
  17. @classmethod
  18. def __init__( self, sys_argv=None):
  19. if sys_argv is None:
  20. sys_argv = sys.argv[ 1:]
  21. parser = argparse.ArgumentParser() # 命令行参数修饰器
  22. parser.add_argument( '--batch-size', # 添加参数
  23. help= 'Batch size to use for training',
  24. default= 1024,
  25. type= int,
  26. )
  27. parser.add_argument( '--num-workers',
  28. help= 'Number of worker processes for background data loading',
  29. default= 8,
  30. type= int,
  31. )
  32. self.cli_args = parser.parse_args(sys_argv) # 解释参数
  33. def main( self):
  34. log.info( "Starting {}, {}". format( type(self).__name__, self.cli_args))
  35. self.prep_dl = DataLoader(
  36. LunaDataset(
  37. sortby_str= 'series_uid',
  38. ),
  39. batch_size=self.cli_args.batch_size,
  40. num_workers=self.cli_args.num_workers,
  41. )
  42. batch_iter = enumerateWithEstimate( # 尝试加载数据集,预估加载整个数据集所需时间
  43. self.prep_dl,
  44. "Stuffing cache",
  45. start_ndx=self.prep_dl.num_workers,
  46. )
  47. for _ in batch_iter:
  48. pass
  49. if __name__ == '__main__':
  50. LunaPrepCacheApp().main() # 对类的__init__函数使用了@classmethod修饰器,所以可以不需要实例化,直接调用类内函数

在jupyter运行方法可参考原书代码的【p2_run_everything.ipynb】的【chapter11-cell2】。具体运行方法:

step1:加载相关库和函数

step2:使用命令行形式调用LunaPrepCacheApp函数。

 运行结果:

从下图可见,数据集中一个551065个样本,每个batch有1024个样本,一共539个batch,加载16个batch后,推算出加载完所有batch的时间要6个小时05分。

4. training.py

注释了部分代码,其中部分关于tensorboard的代码注释放到第六篇文章的笔记。训练结果及代码如下:


  
  1. import argparse
  2. import datetime
  3. import os
  4. import sys
  5. import numpy as np
  6. from torch.utils.tensorboard import SummaryWriter
  7. import torch
  8. import torch.nn as nn
  9. from torch.optim import SGD, Adam
  10. from torch.utils.data import DataLoader
  11. from util.util import enumerateWithEstimate
  12. from .dsets import LunaDataset
  13. from util.logconf import logging
  14. from .model import LunaModel
  15. log = logging.getLogger(__name__)
  16. # log.setLevel(logging.WARN)
  17. log.setLevel(logging.INFO)
  18. log.setLevel(logging.DEBUG)
  19. # Used for computeBatchLoss and logMetrics to index into metrics_t/metrics_a
  20. # 将每个样本在训练时候的label、预测值、loss存在了一个矩阵,用于打印结果和tensorboard上显示
  21. # 矩阵第一行为label,第二行为预测值,第三行为loss值,每一列为一个样本
  22. METRICS_LABEL_NDX= 0 # label的行索引
  23. METRICS_PRED_NDX= 1 # 预测值行索引
  24. METRICS_LOSS_NDX= 2 # loss值行索引
  25. METRICS_SIZE = 3 # 矩阵行数
  26. class LunaTrainingApp:
  27. def __init__( self, sys_argv=None):
  28. if sys_argv is None:
  29. sys_argv = sys.argv[ 1:]
  30. parser = argparse.ArgumentParser()
  31. parser.add_argument( '--num-workers',
  32. help= 'Number of worker processes for background data loading',
  33. default= 6, # 使用的CPU核心数,我用的i5-12490f为6核
  34. type= int,
  35. )
  36. parser.add_argument( '--batch-size',
  37. help= 'Batch size to use for training',
  38. default= 24, # 每个batch样本数
  39. type= int,
  40. )
  41. parser.add_argument( '--epochs',
  42. help= 'Number of epochs to train for',
  43. default= 1, # 训练的代数
  44. type= int,
  45. )
  46. parser.add_argument( '--tb-prefix',
  47. default= 'p2ch11',
  48. help= "Data prefix to use for Tensorboard run. Defaults to chapter.",
  49. )
  50. parser.add_argument( 'comment',
  51. help= "Comment suffix for Tensorboard run.",
  52. nargs= '?',
  53. default= 'dwlpt',
  54. )
  55. self.cli_args = parser.parse_args(sys_argv)
  56. self.time_str = datetime.datetime.now().strftime( '%Y-%m-%d_%H.%M.%S')
  57. self.trn_writer = None
  58. self.val_writer = None
  59. self.totalTrainingSamples_count = 0
  60. self.use_cuda = torch.cuda.is_available()
  61. self.device = torch.device( "cuda" if self.use_cuda else "cpu")
  62. self.model = self.initModel() # 将模型搬到cuda
  63. self.optimizer = self.initOptimizer() # 定义优化器
  64. def initModel( self):
  65. model = LunaModel()
  66. if self.use_cuda:
  67. log.info( "Using CUDA; {} devices.". format(torch.cuda.device_count()))
  68. if torch.cuda.device_count() > 1:
  69. model = nn.DataParallel(model) # 如果有多个gpu,分配多给GPU训练
  70. model = model.to(self.device)
  71. return model
  72. def initOptimizer( self):
  73. # 一般第一次训练用SGD看看效果,再选择其他优化器。比较常用参数为lr=0.001,momentum=0.99
  74. return SGD(self.model.parameters(), lr= 0.001, momentum= 0.99)
  75. # return Adam(self.model.parameters())
  76. def initTrainDl( self):
  77. # 由于LunaDataset的getCtRawCandidate被diskcache修饰,所以第一次加载数据集时,需要从文件读取数据,
  78. # 同时数据处理后会缓存到磁盘,速度较慢;第二次开始,会直接从缓存加载,速度会较快。
  79. train_ds = LunaDataset(
  80. val_stride= 10,
  81. isValSet_bool= False,
  82. )
  83. batch_size = self.cli_args.batch_size
  84. if self.use_cuda:
  85. batch_size *= torch.cuda.device_count()
  86. train_dl = DataLoader(
  87. train_ds,
  88. batch_size=batch_size,
  89. num_workers=self.cli_args.num_workers,
  90. pin_memory=self.use_cuda,
  91. )
  92. return train_dl
  93. def initValDl( self):
  94. val_ds = LunaDataset(
  95. val_stride= 10,
  96. isValSet_bool= True,
  97. )
  98. batch_size = self.cli_args.batch_size
  99. if self.use_cuda:
  100. batch_size *= torch.cuda.device_count()
  101. val_dl = DataLoader(
  102. val_ds,
  103. batch_size=batch_size,
  104. num_workers=self.cli_args.num_workers,
  105. pin_memory=self.use_cuda,
  106. )
  107. return val_dl
  108. def initTensorboardWriters( self):
  109. if self.trn_writer is None:
  110. log_dir = os.path.join( 'runs', self.cli_args.tb_prefix, self.time_str)
  111. self.trn_writer = SummaryWriter(
  112. log_dir=log_dir + '-trn_cls-' + self.cli_args.comment)
  113. self.val_writer = SummaryWriter(
  114. log_dir=log_dir + '-val_cls-' + self.cli_args.comment)
  115. def main( self):
  116. log.info( "Starting {}, {}". format( type(self).__name__, self.cli_args))
  117. train_dl = self.initTrainDl()
  118. val_dl = self.initValDl()
  119. for epoch_ndx in range( 1, self.cli_args.epochs + 1):
  120. log.info( "Epoch {} of {}, {}/{} batches of size {}*{}". format(
  121. epoch_ndx,
  122. self.cli_args.epochs,
  123. len(train_dl),
  124. len(val_dl),
  125. self.cli_args.batch_size,
  126. (torch.cuda.device_count() if self.use_cuda else 1),
  127. ))
  128. trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
  129. self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)
  130. valMetrics_t = self.doValidation(epoch_ndx, val_dl)
  131. self.logMetrics(epoch_ndx, 'val', valMetrics_t)
  132. if hasattr(self, 'trn_writer'):
  133. self.trn_writer.close()
  134. self.val_writer.close()
  135. def doTraining( self, epoch_ndx, train_dl):
  136. self.model.train()
  137. trnMetrics_g = torch.zeros(
  138. METRICS_SIZE,
  139. len(train_dl.dataset),
  140. device=self.device,
  141. )
  142. # batch_iter = enumerateWithEstimate(
  143. # train_dl,
  144. # "E{} Training".format(epoch_ndx),
  145. # start_ndx=train_dl.num_workers,
  146. # )
  147. for batch_ndx, batch_tup in enumerate(train_dl):
  148. self.optimizer.zero_grad()
  149. loss_var = self.computeBatchLoss(
  150. batch_ndx,
  151. batch_tup,
  152. train_dl.batch_size,
  153. trnMetrics_g
  154. )
  155. loss_var.backward()
  156. self.optimizer.step()
  157. # # This is for adding the model graph to TensorBoard.
  158. # if epoch_ndx == 1 and batch_ndx == 0:
  159. # with torch.no_grad():
  160. # model = LunaModel()
  161. # self.trn_writer.add_graph(model, batch_tup[0], verbose=True)
  162. # self.trn_writer.close()
  163. self.totalTrainingSamples_count += len(train_dl.dataset)
  164. return trnMetrics_g.to( 'cpu')
  165. def doValidation( self, epoch_ndx, val_dl):
  166. with torch.no_grad():
  167. self.model. eval()
  168. valMetrics_g = torch.zeros(
  169. METRICS_SIZE,
  170. len(val_dl.dataset),
  171. device=self.device,
  172. )
  173. batch_iter = enumerateWithEstimate(
  174. val_dl,
  175. "E{} Validation ". format(epoch_ndx),
  176. start_ndx=val_dl.num_workers,
  177. )
  178. for batch_ndx, batch_tup in batch_iter:
  179. self.computeBatchLoss(
  180. batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g)
  181. return valMetrics_g.to( 'cpu')
  182. def computeBatchLoss( self, batch_ndx, batch_tup, batch_size, metrics_g):
  183. input_t, label_t, _series_list, _center_list = batch_tup
  184. input_g = input_t.to(self.device, non_blocking= True)
  185. label_g = label_t.to(self.device, non_blocking= True)
  186. logits_g, probability_g = self.model(input_g)
  187. loss_func = nn.CrossEntropyLoss(reduction= 'none') # reduction=none时,将每个样本的loss返回
  188. loss_g = loss_func(
  189. logits_g,
  190. label_g[:, 1],
  191. )
  192. start_ndx = batch_ndx * batch_size
  193. end_ndx = start_ndx + label_t.size( 0)
  194. # 将训练结果存到矩阵
  195. metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = \
  196. label_g[:, 1].detach()
  197. metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = \
  198. probability_g[:, 1].detach()
  199. metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = \
  200. loss_g.detach()
  201. return loss_g.mean()
  202. def logMetrics(
  203. self,
  204. epoch_ndx,
  205. mode_str,
  206. metrics_t,
  207. classificationThreshold=0.5,
  208. ):
  209. self.initTensorboardWriters()
  210. log.info( "E{} {}". format(
  211. epoch_ndx,
  212. type(self).__name__,
  213. ))
  214. negLabel_mask = metrics_t[METRICS_LABEL_NDX] <= classificationThreshold
  215. negPred_mask = metrics_t[METRICS_PRED_NDX] <= classificationThreshold
  216. posLabel_mask = ~negLabel_mask
  217. posPred_mask = ~negPred_mask
  218. neg_count = int(negLabel_mask. sum())
  219. pos_count = int(posLabel_mask. sum())
  220. neg_correct = int((negLabel_mask & negPred_mask). sum())
  221. pos_correct = int((posLabel_mask & posPred_mask). sum())
  222. metrics_dict = {}
  223. metrics_dict[ 'loss/all'] = \
  224. metrics_t[METRICS_LOSS_NDX].mean()
  225. metrics_dict[ 'loss/neg'] = \
  226. metrics_t[METRICS_LOSS_NDX, negLabel_mask].mean()
  227. metrics_dict[ 'loss/pos'] = \
  228. metrics_t[METRICS_LOSS_NDX, posLabel_mask].mean()
  229. metrics_dict[ 'correct/all'] = (pos_correct + neg_correct) \
  230. / np.float32(metrics_t.shape[ 1]) * 100
  231. metrics_dict[ 'correct/neg'] = neg_correct / np.float32(neg_count) * 100
  232. metrics_dict[ 'correct/pos'] = pos_correct / np.float32(pos_count) * 100
  233. log.info(
  234. ( "E{} {:8} {loss/all:.4f} loss, "
  235. + "{correct/all:-5.1f}% correct, "
  236. ). format(
  237. epoch_ndx,
  238. mode_str,
  239. **metrics_dict,
  240. )
  241. )
  242. log.info(
  243. ( "E{} {:8} {loss/neg:.4f} loss, "
  244. + "{correct/neg:-5.1f}% correct ({neg_correct:} of {neg_count:})"
  245. ). format(
  246. epoch_ndx,
  247. mode_str + '_neg',
  248. neg_correct=neg_correct,
  249. neg_count=neg_count,
  250. **metrics_dict,
  251. )
  252. )
  253. log.info(
  254. ( "E{} {:8} {loss/pos:.4f} loss, "
  255. + "{correct/pos:-5.1f}% correct ({pos_correct:} of {pos_count:})"
  256. ). format(
  257. epoch_ndx,
  258. mode_str + '_pos',
  259. pos_correct=pos_correct,
  260. pos_count=pos_count,
  261. **metrics_dict,
  262. )
  263. )
  264. writer = getattr(self, mode_str + '_writer')
  265. for key, value in metrics_dict.items():
  266. writer.add_scalar(key, value, self.totalTrainingSamples_count)
  267. writer.add_pr_curve(
  268. 'pr',
  269. metrics_t[METRICS_LABEL_NDX],
  270. metrics_t[METRICS_PRED_NDX],
  271. self.totalTrainingSamples_count,
  272. )
  273. bins = [x/ 50.0 for x in range( 51)]
  274. negHist_mask = negLabel_mask & (metrics_t[METRICS_PRED_NDX] > 0.01)
  275. posHist_mask = posLabel_mask & (metrics_t[METRICS_PRED_NDX] < 0.99)
  276. if negHist_mask. any():
  277. writer.add_histogram(
  278. 'is_neg',
  279. metrics_t[METRICS_PRED_NDX, negHist_mask],
  280. self.totalTrainingSamples_count,
  281. bins=bins,
  282. )
  283. if posHist_mask. any():
  284. writer.add_histogram(
  285. 'is_pos',
  286. metrics_t[METRICS_PRED_NDX, posHist_mask],
  287. self.totalTrainingSamples_count,
  288. bins=bins,
  289. )
  290. # score = 1 \
  291. # + metrics_dict['pr/f1_score'] \
  292. # - metrics_dict['loss/mal'] * 0.01 \
  293. # - metrics_dict['loss/all'] * 0.0001
  294. #
  295. # return score
  296. # def logModelMetrics(self, model):
  297. # writer = getattr(self, 'trn_writer')
  298. #
  299. # model = getattr(model, 'module', model)
  300. #
  301. # for name, param in model.named_parameters():
  302. # if param.requires_grad:
  303. # min_data = float(param.data.min())
  304. # max_data = float(param.data.max())
  305. # max_extent = max(abs(min_data), abs(max_data))
  306. #
  307. # # bins = [x/50*max_extent for x in range(-50, 51)]
  308. #
  309. # try:
  310. # writer.add_histogram(
  311. # name.rsplit('.', 1)[-1] + '/' + name,
  312. # param.data.cpu().numpy(),
  313. # # metrics_a[METRICS_PRED_NDX, negHist_mask],
  314. # self.totalTrainingSamples_count,
  315. # # bins=bins,
  316. # )
  317. # except Exception as e:
  318. # log.error([min_data, max_data])
  319. # raise
  320. if __name__ == '__main__':
  321. LunaTrainingApp().main()


转载:https://blog.csdn.net/wxyczhyza/article/details/127684935
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场