In the previous article, dsets.py was implemented to wrap and load the dataset; with that in place, we can now build the model and write the training script to actually train it. This article is my notes on Chapter 11 of Deep Learning with PyTorch (《PyTorch深度学习实战》).
I. Goals
1. Build a simple convolutional neural network
2. Write the training function
3. Write the data structure for the training log (loss, accuracy, etc. during training and validation)
4. Visualize the training information with TensorBoard
II. Key points
1. A generic, process-level way to invoke functions
Cell 2 of the book's code in code/p2_run_everything.ipynb defines a generic, process-style calling mechanism that can invoke functions from any of the scripts. Personally I find it rather cumbersome and not at all user-friendly, so I suggest not spending much effort on this part of the code; it is enough to know what it does.
def run(app, *argv):
    argv = list(argv)
    argv.insert(0, '--num-workers=4')  # <1> use 4 worker processes
    log.info("Running: {}({!r}).main()".format(app, argv))

    app_cls = importstr(*app.rsplit('.', 1))  # <2> dynamically import the class
    app_cls(argv).main()  # call the app class's main()

    log.info("Finished: {}.{!r}).main()".format(app, argv))
Usage example: import the LunaTrainingApp class from training.py in the p2ch11 folder and call its main function, passing epochs=1 as the argument:
run('p2ch11.training.LunaTrainingApp', '--epochs=1')
Details:
1.1 The importstr function
This function dynamically imports modules and the functions inside them, playing the same role as from <pkg_name> import <func_name>. With importstr a function can be loaded on the fly, without declaring an import ahead of the call.
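The book implements importstr in util/util.py; below is only a minimal sketch of what such a helper can look like, written with importlib (the names here are illustrative, not the book's exact implementation):

import importlib

def importstr(module_str, name_str):
    """Dynamically do the equivalent of `from module_str import name_str`."""
    module = importlib.import_module(module_str)  # e.g. 'p2ch11.training'
    return getattr(module, name_str)              # e.g. the class LunaTrainingApp

# Usage (equivalent to: from p2ch11.training import LunaTrainingApp):
# LunaTrainingApp = importstr('p2ch11.training', 'LunaTrainingApp')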
1.2 The rsplit function
Usage: list = str.rsplit(sep, maxsplit). See the article below for details. In short, it splits the string str on the separator sep, starting from the right-hand side, at most maxsplit times, and returns the pieces as a list.
Reference: 《Python实用语法之rsplit》 (CSDN blog)
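For example, this is exactly how run() splits the dotted path into a module name and a class name (plain Python, nothing book-specific):

path = 'p2ch11.training.LunaTrainingApp'
print(path.rsplit('.', 1))  # ['p2ch11.training', 'LunaTrainingApp'] -- one split, from the right
print(path.rsplit('.'))     # ['p2ch11', 'training', 'LunaTrainingApp'] -- no maxsplit: split on every '.'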
1.3 The argparse library
The book's prepcache.py uses the argparse library. argparse lets a script run from the command line parse the argument names and values we type in. Once an argument parser is defined, we can run the script the same way we use a command such as conda, e.g. something like conda --user xxx.
For details on argparse, see the following article:
《argparse.ArgumentParser()的用法》 (CSDN blog)
A minimal example:
import argparse

parser = argparse.ArgumentParser()  # create an argument parser
parser.add_argument("--arg1", type=int, help="an integer", default=1)  # declare a parameter via --argName, of type int
parser.add_argument("--arg2", type=int, help="an integer", default=2)  # declare a parameter via --argName, of type int

args = parser.parse_args()  # parse the arguments

print("arg1 = {0}".format(args.arg1))
print("arg2 = {0}".format(args.arg2))
Running it from the command line gives:

(pytorch) E:\CT\code>python test2.py --arg1 1 --arg2 2
arg1 = 1
arg2 = 2
1.4 The @classmethod decorator
The book's prepcache.py uses the @classmethod decorator, which allows a method to be called on the class itself, without instantiating an object first.
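As a quick illustration (generic Python, not the book's code):

class Greeter:
    greeting = 'hello'

    @classmethod
    def greet(cls, name):
        # cls is the class itself, so no instance is needed
        return '{} {}'.format(cls.greeting, name)

print(Greeter.greet('world'))  # called directly on the class: prints 'hello world'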
2. Building the model
Chapter 11 uses a plain network architecture of stacked convolutions followed by a linear layer, nothing special. Since the task is only a simple binary classification (is the nodule a tumor or not), a single linear layer is enough. The convolutions and pooling are 3D.
2.1 Multi-GPU setup
Multi-GPU training can be done with nn.DataParallel(model) or with DistributedDataParallel. The former is simpler and is generally used for multiple GPUs in a single machine; the latter takes more configuration and is generally used for multi-GPU training across several machines.
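A minimal sketch of the single-machine, multi-GPU case with nn.DataParallel (it mirrors what initModel() in training.py does below; the import assumes the book's repository layout):

import torch
import torch.nn as nn

from p2ch11.model import LunaModel

model = LunaModel()
if torch.cuda.is_available():
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)  # replicate the model across all visible GPUs
    model = model.to(torch.device('cuda'))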
2.2 Optimizer
When starting a new training run, it is usually worth trying SGD with momentum first (lr=0.001, momentum=0.9); if that does not work well, switch to another optimizer such as Adam.
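For instance (a sketch only; note that the chapter's own initOptimizer() below uses momentum=0.99):

from torch.optim import SGD, Adam

from p2ch11.model import LunaModel  # assumes the book's repository layout

model = LunaModel()
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
# optimizer = Adam(model.parameters())  # a common fallback if SGD does not converge well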
2.3 Model input size
As described for the Ct class in the previous article, the width_irc parameter defines the size of each candidate chunk in the IRC coordinate system; it is also the input size that the dataset feeds into the model.
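The candidate chunks are 32 x 48 x 48 voxels (index, row, col), so after adding a channel dimension a single-sample batch looks like this (a small sketch; width_irc matches the value used in dsets.py from the previous article):

import torch

width_irc = (32, 48, 48)           # (I, R, C) size of each candidate chunk
x = torch.randn(1, 1, *width_irc)  # (batch, channel, I, R, C)
# logits, probabilities = LunaModel()(x)  # the forward pass returns both heads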
2.4 Model summary
The summary function from either the torchinfo or the torchsummary library can print the model's parameter information, for example:
from p2ch11.model import LunaModel
import torchinfo  # install with: conda install torchinfo

model = LunaModel()
torchinfo.summary(model, (1, 32, 48, 48), batch_dim=0,
                  col_names=('input_size', 'output_size', 'num_params',
                             'kernel_size', 'mult_adds'), verbose=1)
The output, i.e. the model summary, is:
==========================================================================================================
Layer (type:depth-idx)   Input Shape          Output Shape         Param #   Kernel Shape   Mult-Adds
==========================================================================================================
LunaModel                [1, 1, 32, 48, 48]   [1, 2]               --        --             --
├─BatchNorm3d: 1-1       [1, 1, 32, 48, 48]   [1, 1, 32, 48, 48]   2         --             2
├─LunaBlock: 1-2         [1, 1, 32, 48, 48]   [1, 8, 16, 24, 24]   --        --             --
│    └─Conv3d: 2-1       [1, 1, 32, 48, 48]   [1, 8, 32, 48, 48]   224       [3, 3, 3]      16,515,072
│    └─ReLU: 2-2         [1, 8, 32, 48, 48]   [1, 8, 32, 48, 48]   --        --             --
│    └─Conv3d: 2-3       [1, 8, 32, 48, 48]   [1, 8, 32, 48, 48]   1,736     [3, 3, 3]      127,991,808
│    └─ReLU: 2-4         [1, 8, 32, 48, 48]   [1, 8, 32, 48, 48]   --        --             --
│    └─MaxPool3d: 2-5    [1, 8, 32, 48, 48]   [1, 8, 16, 24, 24]   --        2              --
├─LunaBlock: 1-3         [1, 8, 16, 24, 24]   [1, 16, 8, 12, 12]   --        --             --
│    └─Conv3d: 2-6       [1, 8, 16, 24, 24]   [1, 16, 16, 24, 24]  3,472     [3, 3, 3]      31,997,952
│    └─ReLU: 2-7         [1, 16, 16, 24, 24]  [1, 16, 16, 24, 24]  --        --             --
│    └─Conv3d: 2-8       [1, 16, 16, 24, 24]  [1, 16, 16, 24, 24]  6,928     [3, 3, 3]      63,848,448
│    └─ReLU: 2-9         [1, 16, 16, 24, 24]  [1, 16, 16, 24, 24]  --        --             --
│    └─MaxPool3d: 2-10   [1, 16, 16, 24, 24]  [1, 16, 8, 12, 12]   --        2              --
├─LunaBlock: 1-4         [1, 16, 8, 12, 12]   [1, 32, 4, 6, 6]     --        --             --
│    └─Conv3d: 2-11      [1, 16, 8, 12, 12]   [1, 32, 8, 12, 12]   13,856    [3, 3, 3]      15,962,112
│    └─ReLU: 2-12        [1, 32, 8, 12, 12]   [1, 32, 8, 12, 12]   --        --             --
│    └─Conv3d: 2-13      [1, 32, 8, 12, 12]   [1, 32, 8, 12, 12]   27,680    [3, 3, 3]      31,887,360
│    └─ReLU: 2-14        [1, 32, 8, 12, 12]   [1, 32, 8, 12, 12]   --        --             --
│    └─MaxPool3d: 2-15   [1, 32, 8, 12, 12]   [1, 32, 4, 6, 6]     --        2              --
├─LunaBlock: 1-5         [1, 32, 4, 6, 6]     [1, 64, 2, 3, 3]     --        --             --
│    └─Conv3d: 2-16      [1, 32, 4, 6, 6]     [1, 64, 4, 6, 6]     55,360    [3, 3, 3]      7,971,840
│    └─ReLU: 2-17        [1, 64, 4, 6, 6]     [1, 64, 4, 6, 6]     --        --             --
│    └─Conv3d: 2-18      [1, 64, 4, 6, 6]     [1, 64, 4, 6, 6]     110,656   [3, 3, 3]      15,934,464
│    └─ReLU: 2-19        [1, 64, 4, 6, 6]     [1, 64, 4, 6, 6]     --        --             --
│    └─MaxPool3d: 2-20   [1, 64, 4, 6, 6]     [1, 64, 2, 3, 3]     --        2              --
├─Linear: 1-6            [1, 1152]            [1, 2]               2,306     --             2,306
├─Softmax: 1-7           [1, 2]               [1, 2]               --        --             --
==========================================================================================================
Total params: 222,220
Trainable params: 222,220
Non-trainable params: 0
Total mult-adds (M): 312.11
==========================================================================================================
Input size (MB): 0.29
Forward/backward pass size (MB): 13.12
Params size (MB): 0.89
Estimated Total Size (MB): 14.31
==========================================================================================================

Process finished with exit code 0
3. Initialization
Before training starts, the weights need to be initialized. The initialization scheme is a generic one; see the _init_weights function in the book's model.py:
def _init_weights(self):
    for m in self.modules():
        if type(m) in {
            nn.Linear,
            nn.Conv3d,
            nn.Conv2d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
        }:
            nn.init.kaiming_normal_(
                m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
            )
            if m.bias is not None:
                fan_in, fan_out = \
                    nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                bound = 1 / math.sqrt(fan_out)
                nn.init.normal_(m.bias, -bound, bound)
4. Estimating how long the code will run
The book's code defines the enumerateWithEstimate function to estimate how long a given piece of code will take to run. The key is the yield keyword, which lets enumerateWithEstimate iterate over the dataset one step at a time. For the usage of yield, see the article below.
Reference: 《python中yield的用法详解——最简单,最清晰的解释》 (CSDN blog)
Overall, a function func that contains the yield keyword executes as if it had breakpoints (see the sketch after this list):
1. On the first call, execution runs up to the yield keyword and hands back the value of the expression to its right, much like return; the code after the yield is not executed yet.
2. When func is resumed with next(), it continues from the code right after the yield until it reaches a yield again; if the yield sits inside a loop (as it does in enumerateWithEstimate), each resume runs the rest of the loop body and comes back around to the yield on the next iteration.
3. Every further next() call repeats step 2, until the function body finally ends and the generator is exhausted.
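A tiny generator that shows this behavior (plain Python, not the book's code):

def counter(n):
    for i in range(n):
        yield i  # execution pauses here and hands i back to the caller
    # falling off the end of the function ends the iteration (StopIteration)

gen = counter(3)
print(next(gen))  # 0 -- runs up to the first yield
print(next(gen))  # 1 -- resumes after the yield, loops, pauses at the yield again
print(list(gen))  # [2] -- the remaining values; after this the generator is exhausted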
5. Speeding up data loading
In the book, the author uses the diskcache library to cache the dataset to disk the first time it is loaded; the next time training or validation loads the data, it is read straight from the disk cache, which saves the bulk of the data-loading and preprocessing time. For details on diskcache, see:
https://blog.csdn.net/wxyczhyza/article/details/127773721
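The caching idea, in miniature (the book wraps FanoutCache in util/disk.py; the cache directory and the toy function here are just for illustration):

import time

from diskcache import FanoutCache

cache = FanoutCache('data-cache/demo', shards=4, timeout=1)

@cache.memoize(typed=True)
def slow_square(x):
    time.sleep(1)  # stands in for the expensive CT loading / preprocessing
    return x * x

print(slow_square(12))  # slow the first time; the result is written to the on-disk cache
print(slow_square(12))  # fast: served from the cache, even across program restarts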
III. Code
The book's original code can be downloaded via the code link in the article referenced below; here I paste the code with my own comments added.
1. The network model: model.py
The code is as follows:
import math

from torch import nn as nn

from util.logconf import logging

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)


class LunaModel(nn.Module):
    def __init__(self, in_channels=1, conv_channels=8):
        super().__init__()

        self.tail_batchnorm = nn.BatchNorm3d(1)

        self.block1 = LunaBlock(in_channels, conv_channels)
        self.block2 = LunaBlock(conv_channels, conv_channels * 2)
        self.block3 = LunaBlock(conv_channels * 2, conv_channels * 4)
        self.block4 = LunaBlock(conv_channels * 4, conv_channels * 8)

        self.head_linear = nn.Linear(1152, 2)
        self.head_softmax = nn.Softmax(dim=1)

        self._init_weights()

    # see also https://github.com/pytorch/pytorch/issues/18182
    def _init_weights(self):
        for m in self.modules():
            if type(m) in {
                nn.Linear,
                nn.Conv3d,
                nn.Conv2d,
                nn.ConvTranspose2d,
                nn.ConvTranspose3d,
            }:
                nn.init.kaiming_normal_(
                    m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
                )
                if m.bias is not None:
                    fan_in, fan_out = \
                        nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                    bound = 1 / math.sqrt(fan_out)
                    nn.init.normal_(m.bias, -bound, bound)

    def forward(self, input_batch):
        bn_output = self.tail_batchnorm(input_batch)

        block_out = self.block1(bn_output)
        block_out = self.block2(block_out)
        block_out = self.block3(block_out)
        block_out = self.block4(block_out)

        conv_flat = block_out.view(
            block_out.size(0),
            -1,
        )
        linear_output = self.head_linear(conv_flat)

        return linear_output, self.head_softmax(linear_output)


class LunaBlock(nn.Module):
    def __init__(self, in_channels, conv_channels):
        super().__init__()

        self.conv1 = nn.Conv3d(
            in_channels, conv_channels, kernel_size=3, padding=1, bias=True,
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(
            conv_channels, conv_channels, kernel_size=3, padding=1, bias=True,
        )
        self.relu2 = nn.ReLU(inplace=True)

        self.maxpool = nn.MaxPool3d(2, 2)

    def forward(self, input_batch):
        block_out = self.conv1(input_batch)
        block_out = self.relu1(block_out)
        block_out = self.conv2(block_out)
        block_out = self.relu2(block_out)

        return self.maxpool(block_out)
2. The enumerateWithEstimate function
Location: util\util.py
The function relies on the yield keyword, which turns enumerateWithEstimate into a generator: it iterates over the dataset step by step and uses the time taken per iteration to estimate the total time needed to load the whole dataset.
# module-level imports needed by this function (util/util.py contains more)
import datetime
import time

from util.logconf import logging
log = logging.getLogger(__name__)


# Estimates how long a full pass over the iterable will take. The idea:
# step 1: using yield, load part of the dataset and compute the average time per item:
#         delta_t = elapsed time / number of items processed so far
# step 2: scale by the length of the iterable to estimate the total time:
#         t_dataset = delta_t * len(iter)
def enumerateWithEstimate(
        iter,           # the dataset iterable; the goal is to estimate how long a full pass over it takes
        desc_str,       # description text used in the log messages; pick anything meaningful
        start_ndx=0,    # number of iterations to skip before timing starts (e.g. start_ndx=3 means the first two are ignored and timing starts on the third)
        print_ndx=4,    # iteration index of the first log message; grows by a factor of backoff after each message (print_ndx = print_ndx * backoff); defaults to 4
        backoff=None,   # multiplier between two consecutive log messages: print_ndx = print_ndx * backoff
        iter_len=None,  # length of the iterable; if not given, iter_len = len(iter)
):
    """
    In terms of behavior, `enumerateWithEstimate` is almost identical
    to the standard `enumerate` (the differences are things like how
    our function returns a generator, while `enumerate` returns a
    specialized `<enumerate object at 0x...>`).

    However, the side effects (logging, specifically) are what make the
    function interesting.

    :param iter: `iter` is the iterable that will be passed into
        `enumerate`. Required.

    :param desc_str: This is a human-readable string that describes
        what the loop is doing. The value is arbitrary, but should be
        kept reasonably short. Things like `"epoch 4 training"` or
        `"deleting temp files"` or similar would all make sense.

    :param start_ndx: This parameter defines how many iterations of the
        loop should be skipped before timing actually starts. Skipping
        a few iterations can be useful if there are startup costs like
        caching that are only paid early on, resulting in a skewed
        average when those early iterations dominate the average time
        per iteration.

        NOTE: Using `start_ndx` to skip some iterations makes the time
        spent performing those iterations not be included in the
        displayed duration. Please account for this if you use the
        displayed duration for anything formal.

        This parameter defaults to `0`.

    :param print_ndx: determines which loop interation that the timing
        logging will start on. The intent is that we don't start
        logging until we've given the loop a few iterations to let the
        average time-per-iteration a chance to stablize a bit. We
        require that `print_ndx` not be less than `start_ndx` times
        `backoff`, since `start_ndx` greater than `0` implies that the
        early N iterations are unstable from a timing perspective.

        `print_ndx` defaults to `4`.

    :param backoff: This is used to how many iterations to skip before
        logging again. Frequent logging is less interesting later on,
        so by default we double the gap between logging messages each
        time after the first.

        `backoff` defaults to `2` unless iter_len is > 1000, in which
        case it defaults to `4`.

    :param iter_len: Since we need to know the number of items to
        estimate when the loop will finish, that can be provided by
        passing in a value for `iter_len`. If a value isn't provided,
        then it will be set by using the value of `len(iter)`.

    :return:
    """
    if iter_len is None:
        iter_len = len(iter)

    if backoff is None:
        backoff = 2
        while backoff ** 7 < iter_len:
            backoff *= 2

    assert backoff >= 2
    while print_ndx < start_ndx * backoff:
        print_ndx *= backoff

    log.warning("{} ----/{}, starting".format(
        desc_str,
        iter_len,
    ))
    start_ts = time.time()
    for (current_ndx, item) in enumerate(iter):
        yield (current_ndx, item)
        if current_ndx == print_ndx:
            # ... <1> step 1: measure the time spent on the batches so far; step 2: average it
            #         per item; step 3: multiply by the iterable length to estimate the total time
            duration_sec = ((time.time() - start_ts)
                            / (current_ndx - start_ndx + 1)
                            * (iter_len-start_ndx)
                            )

            done_dt = datetime.datetime.fromtimestamp(start_ts + duration_sec)
            done_td = datetime.timedelta(seconds=duration_sec)

            log.info("{} {:-4}/{}, done at {}, {}".format(
                desc_str,
                current_ndx,
                iter_len,
                str(done_dt).rsplit('.', 1)[0],  # estimated wall-clock time at which the full pass will finish, based on current_ndx iterations
                str(done_td).rsplit('.', 1)[0],  # estimated total duration of the full pass, based on current_ndx iterations
            ))

            print_ndx *= backoff

        if current_ndx + 1 == start_ndx:
            start_ts = time.time()

    log.warning("{} ----/{}, done at {}".format(
        desc_str,
        iter_len,
        str(datetime.datetime.now()).rsplit('.', 1)[0],
    ))
3. prepcache.py
This script does a trial pass over the entire dataset to see how long loading it takes; the core of it is the call to enumerateWithEstimate.
import argparse  # command-line argument parser
import sys

import numpy as np

import torch.nn as nn
from torch.autograd import Variable
from torch.optim import SGD
from torch.utils.data import DataLoader

from util.util import enumerateWithEstimate
from .dsets import LunaDataset
from util.logconf import logging
from .model import LunaModel

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
log.setLevel(logging.INFO)
# log.setLevel(logging.DEBUG)


class LunaPrepCacheApp:
    @classmethod
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()  # command-line argument parser
        parser.add_argument('--batch-size',  # add an argument
            help='Batch size to use for training',
            default=1024,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )

        self.cli_args = parser.parse_args(sys_argv)  # parse the arguments

    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        self.prep_dl = DataLoader(
            LunaDataset(
                sortby_str='series_uid',
            ),
            batch_size=self.cli_args.batch_size,
            num_workers=self.cli_args.num_workers,
        )

        batch_iter = enumerateWithEstimate(  # do a trial pass over the dataset and estimate how long a full pass takes
            self.prep_dl,
            "Stuffing cache",
            start_ndx=self.prep_dl.num_workers,
        )
        for _ in batch_iter:
            pass


if __name__ == '__main__':
    LunaPrepCacheApp().main()  # __init__ is decorated with @classmethod, so the class's methods can be called without explicitly instantiating an object first
To run it in Jupyter, see chapter11-cell2 of the book's p2_run_everything.ipynb. Concretely:
step 1: import the relevant libraries and functions;
step 2: invoke LunaPrepCacheApp in command-line style (see the sketch below).
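In the notebook this comes down to something like the following call to the run() helper from section 1 (the exact arguments in p2_run_everything.ipynb may differ slightly):

run('p2ch11.prepcache.LunaPrepCacheApp')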
Result:
As the screenshot showed (the image is not reproduced here), the dataset contains 551,065 samples; with 1,024 samples per batch that makes 539 batches, and after loading 16 batches the estimated time to load them all was 6 hours 05 minutes.
4. training.py
Parts of the code are commented; the notes on the TensorBoard-related parts are left to the sixth article in this series. The training results and the code are as follows:
import argparse
import datetime
import os
import sys

import numpy as np

from torch.utils.tensorboard import SummaryWriter

import torch
import torch.nn as nn
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

from util.util import enumerateWithEstimate
from .dsets import LunaDataset
from util.logconf import logging
from .model import LunaModel

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

# Used for computeBatchLoss and logMetrics to index into metrics_t/metrics_a
# Each sample's label, prediction, and loss during training are stored in one matrix,
# used for printing results and for display in TensorBoard.
# Row 0 holds the labels, row 1 the predictions, row 2 the loss; each column is one sample.
METRICS_LABEL_NDX=0  # row index of the label
METRICS_PRED_NDX=1   # row index of the prediction
METRICS_LOSS_NDX=2   # row index of the loss
METRICS_SIZE = 3     # number of rows in the matrix

class LunaTrainingApp:
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=6,  # number of CPU cores to use; my i5-12490f has 6 cores
            type=int,
        )
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=24,  # samples per batch
            type=int,
        )
        parser.add_argument('--epochs',
            help='Number of epochs to train for',
            default=1,  # number of training epochs
            type=int,
        )

        parser.add_argument('--tb-prefix',
            default='p2ch11',
            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
        )

        parser.add_argument('comment',
            help="Comment suffix for Tensorboard run.",
            nargs='?',
            default='dwlpt',
        )
        self.cli_args = parser.parse_args(sys_argv)
        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')

        self.trn_writer = None
        self.val_writer = None
        self.totalTrainingSamples_count = 0

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        self.model = self.initModel()          # move the model to CUDA
        self.optimizer = self.initOptimizer()  # set up the optimizer

    def initModel(self):
        model = LunaModel()
        if self.use_cuda:
            log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
            if torch.cuda.device_count() > 1:
                model = nn.DataParallel(model)  # if there are several GPUs, spread training over all of them
            model = model.to(self.device)
        return model

    def initOptimizer(self):
        # SGD is usually a good first try before picking another optimizer; commonly used values are lr=0.001, momentum=0.99
        return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
        # return Adam(self.model.parameters())

    def initTrainDl(self):
        # LunaDataset's getCtRawCandidate is decorated with diskcache, so the first pass has to read and
        # preprocess the raw files and write the results to the on-disk cache, which is slow; from the
        # second pass on, the data is loaded directly from the cache, which is much faster.
        train_ds = LunaDataset(
            val_stride=10,
            isValSet_bool=False,
        )

        batch_size = self.cli_args.batch_size
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()

        train_dl = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return train_dl

    def initValDl(self):
        val_ds = LunaDataset(
            val_stride=10,
            isValSet_bool=True,
        )

        batch_size = self.cli_args.batch_size
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()

        val_dl = DataLoader(
            val_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return val_dl

    def initTensorboardWriters(self):
        if self.trn_writer is None:
            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)

            self.trn_writer = SummaryWriter(
                log_dir=log_dir + '-trn_cls-' + self.cli_args.comment)
            self.val_writer = SummaryWriter(
                log_dir=log_dir + '-val_cls-' + self.cli_args.comment)

    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        train_dl = self.initTrainDl()
        val_dl = self.initValDl()

        for epoch_ndx in range(1, self.cli_args.epochs + 1):

            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(train_dl),
                len(val_dl),
                self.cli_args.batch_size,
                (torch.cuda.device_count() if self.use_cuda else 1),
            ))

            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)

            valMetrics_t = self.doValidation(epoch_ndx, val_dl)
            self.logMetrics(epoch_ndx, 'val', valMetrics_t)

        if hasattr(self, 'trn_writer'):
            self.trn_writer.close()
            self.val_writer.close()

    def doTraining(self, epoch_ndx, train_dl):
        self.model.train()
        trnMetrics_g = torch.zeros(
            METRICS_SIZE,
            len(train_dl.dataset),
            device=self.device,
        )

        # batch_iter = enumerateWithEstimate(
        #     train_dl,
        #     "E{} Training".format(epoch_ndx),
        #     start_ndx=train_dl.num_workers,
        # )
        for batch_ndx, batch_tup in enumerate(train_dl):
            self.optimizer.zero_grad()

            loss_var = self.computeBatchLoss(
                batch_ndx,
                batch_tup,
                train_dl.batch_size,
                trnMetrics_g
            )

            loss_var.backward()
            self.optimizer.step()

            # # This is for adding the model graph to TensorBoard.
            # if epoch_ndx == 1 and batch_ndx == 0:
            #     with torch.no_grad():
            #         model = LunaModel()
            #         self.trn_writer.add_graph(model, batch_tup[0], verbose=True)
            #         self.trn_writer.close()

        self.totalTrainingSamples_count += len(train_dl.dataset)

        return trnMetrics_g.to('cpu')

    def doValidation(self, epoch_ndx, val_dl):
        with torch.no_grad():
            self.model.eval()
            valMetrics_g = torch.zeros(
                METRICS_SIZE,
                len(val_dl.dataset),
                device=self.device,
            )

            batch_iter = enumerateWithEstimate(
                val_dl,
                "E{} Validation ".format(epoch_ndx),
                start_ndx=val_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                self.computeBatchLoss(
                    batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g)

        return valMetrics_g.to('cpu')

    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
        input_t, label_t, _series_list, _center_list = batch_tup

        input_g = input_t.to(self.device, non_blocking=True)
        label_g = label_t.to(self.device, non_blocking=True)

        logits_g, probability_g = self.model(input_g)

        loss_func = nn.CrossEntropyLoss(reduction='none')  # reduction='none' returns the loss per sample
        loss_g = loss_func(
            logits_g,
            label_g[:, 1],
        )
        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_t.size(0)

        # store the per-sample results in the metrics matrix
        metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = \
            label_g[:, 1].detach()
        metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = \
            probability_g[:, 1].detach()
        metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = \
            loss_g.detach()

        return loss_g.mean()

    def logMetrics(
            self,
            epoch_ndx,
            mode_str,
            metrics_t,
            classificationThreshold=0.5,
    ):
        self.initTensorboardWriters()
        log.info("E{} {}".format(
            epoch_ndx,
            type(self).__name__,
        ))

        negLabel_mask = metrics_t[METRICS_LABEL_NDX] <= classificationThreshold
        negPred_mask = metrics_t[METRICS_PRED_NDX] <= classificationThreshold

        posLabel_mask = ~negLabel_mask
        posPred_mask = ~negPred_mask

        neg_count = int(negLabel_mask.sum())
        pos_count = int(posLabel_mask.sum())

        neg_correct = int((negLabel_mask & negPred_mask).sum())
        pos_correct = int((posLabel_mask & posPred_mask).sum())

        metrics_dict = {}
        metrics_dict['loss/all'] = \
            metrics_t[METRICS_LOSS_NDX].mean()
        metrics_dict['loss/neg'] = \
            metrics_t[METRICS_LOSS_NDX, negLabel_mask].mean()
        metrics_dict['loss/pos'] = \
            metrics_t[METRICS_LOSS_NDX, posLabel_mask].mean()

        metrics_dict['correct/all'] = (pos_correct + neg_correct) \
            / np.float32(metrics_t.shape[1]) * 100
        metrics_dict['correct/neg'] = neg_correct / np.float32(neg_count) * 100
        metrics_dict['correct/pos'] = pos_correct / np.float32(pos_count) * 100

        log.info(
            ("E{} {:8} {loss/all:.4f} loss, "
                 + "{correct/all:-5.1f}% correct, "
            ).format(
                epoch_ndx,
                mode_str,
                **metrics_dict,
            )
        )
        log.info(
            ("E{} {:8} {loss/neg:.4f} loss, "
                 + "{correct/neg:-5.1f}% correct ({neg_correct:} of {neg_count:})"
            ).format(
                epoch_ndx,
                mode_str + '_neg',
                neg_correct=neg_correct,
                neg_count=neg_count,
                **metrics_dict,
            )
        )
        log.info(
            ("E{} {:8} {loss/pos:.4f} loss, "
                 + "{correct/pos:-5.1f}% correct ({pos_correct:} of {pos_count:})"
            ).format(
                epoch_ndx,
                mode_str + '_pos',
                pos_correct=pos_correct,
                pos_count=pos_count,
                **metrics_dict,
            )
        )

        writer = getattr(self, mode_str + '_writer')

        for key, value in metrics_dict.items():
            writer.add_scalar(key, value, self.totalTrainingSamples_count)

        writer.add_pr_curve(
            'pr',
            metrics_t[METRICS_LABEL_NDX],
            metrics_t[METRICS_PRED_NDX],
            self.totalTrainingSamples_count,
        )

        bins = [x/50.0 for x in range(51)]

        negHist_mask = negLabel_mask & (metrics_t[METRICS_PRED_NDX] > 0.01)
        posHist_mask = posLabel_mask & (metrics_t[METRICS_PRED_NDX] < 0.99)

        if negHist_mask.any():
            writer.add_histogram(
                'is_neg',
                metrics_t[METRICS_PRED_NDX, negHist_mask],
                self.totalTrainingSamples_count,
                bins=bins,
            )
        if posHist_mask.any():
            writer.add_histogram(
                'is_pos',
                metrics_t[METRICS_PRED_NDX, posHist_mask],
                self.totalTrainingSamples_count,
                bins=bins,
            )

        # score = 1 \
        #     + metrics_dict['pr/f1_score'] \
        #     - metrics_dict['loss/mal'] * 0.01 \
        #     - metrics_dict['loss/all'] * 0.0001
        #
        # return score

    # def logModelMetrics(self, model):
    #     writer = getattr(self, 'trn_writer')
    #
    #     model = getattr(model, 'module', model)
    #
    #     for name, param in model.named_parameters():
    #         if param.requires_grad:
    #             min_data = float(param.data.min())
    #             max_data = float(param.data.max())
    #             max_extent = max(abs(min_data), abs(max_data))
    #
    #             # bins = [x/50*max_extent for x in range(-50, 51)]
    #
    #             try:
    #                 writer.add_histogram(
    #                     name.rsplit('.', 1)[-1] + '/' + name,
    #                     param.data.cpu().numpy(),
    #                     # metrics_a[METRICS_PRED_NDX, negHist_mask],
    #                     self.totalTrainingSamples_count,
    #                     # bins=bins,
    #                 )
    #             except Exception as e:
    #                 log.error([min_data, max_data])
    #                 raise


if __name__ == '__main__':
    LunaTrainingApp().main()
Reprinted from: https://blog.csdn.net/wxyczhyza/article/details/127684935