
Knowledge Distillation for Classification Networks [with code]


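Knowledge distillation is a model compression method, although it is really a kind of pseudo-compression: a better-performing teacher network is "compressed" into a weaker student network. Put another way, the student learns under the teacher's guidance and improves as a result.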

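Knowledge distillation is an idea rather than a tool. Unlike other compression methods it has no ready-made library, so you have to implement it yourself for your own requirements and scenario. Distillation is divided into "offline" and "online" distillation. The former sets up a teacher-student (T-S) pair for KD training, while the latter can be described as a kind of self-learning in which the student acts as its own teacher.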

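Distillation is also divided into logit distillation and feature distillation. The former builds a loss between the final outputs of the two networks; the latter builds a loss on some of the intermediate features inside the networks.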
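To make the logit/feature distinction concrete, here is a minimal sketch of both kinds of loss. It is not the code used in this article: the function and class names, the 1x1 convolution adapter, and the choice of MSE for the feature loss are all assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def logit_distill_loss(student_logits, teacher_logits, temperature=2.0):
    """Logit distillation: KL divergence between temperature-softened outputs."""
    log_p_student = F.log_softmax(student_logits / temperature, dim=1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2


class FeatureDistillLoss(nn.Module):
    """Feature distillation: MSE between intermediate feature maps.

    A 1x1 conv (an assumption of this sketch) maps the student's channel
    count to the teacher's so the two feature maps can be compared.
    """

    def __init__(self, student_channels, teacher_channels):
        super().__init__()
        self.adapter = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, student_feat, teacher_feat):
        # The teacher's features are a fixed target, so no gradient flows into them.
        return F.mse_loss(self.adapter(student_feat), teacher_feat.detach())


# Made-up tensors standing in for network outputs and intermediate features.
torch.manual_seed(0)
s_logits, t_logits = torch.randn(4, 10), torch.randn(4, 10)
s_feat, t_feat = torch.randn(4, 32, 7, 7), torch.randn(4, 64, 7, 7)
print("logit loss:", logit_distill_loss(s_logits, t_logits).item())
print("feature loss:", FeatureDistillLoss(32, 64)(s_feat, t_feat).item())
```

The rest of this article uses only the logit variant.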

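This article uses handwritten digits (MNIST) as the example, with resnet18 as the teacher and resnet50 as the student. (You may wonder why resnet50 is the student when it is the stronger network. In my tests on this dataset resnet18 reached higher accuracy than resnet50. My guess is that at low resolution, although resnet50's loss keeps decreasing, the deeper network loses more feature information and network degradation becomes noticeable.) You can of course also try ResNet as the teacher and MobileNet as the student (I trained that combination and found little improvement for MobileNet).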


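Note: no improvements are made to the models or to the distillation method here; the goal is only to show the effect. If you are interested in more refined distillation schemes, you can explore them yourself. (I plan to cover KD training for object detection sometime after next year.)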


Contents

Teacher training code

Student training without KD

KD training code

KD_loss code:

Complete code


 

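Teacher training code

Parameters:

teacher_model: the network chosen as the teacher

train_loader: the training set loader

test_loader: the test set loader

loss_func: the loss function

epochs: the number of training epochs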


  
    def teacher_train(teacher_model, train_loader, test_loader, loss_func, epochs):
        # `device` and `optimizer_teacher` are module-level globals (see the complete code)
        teacher_model.to(device)
        for i in range(epochs):
            # train: re-enter train mode every epoch, because eval() is called below
            teacher_model.train()
            for data, label in train_loader:
                data = data.to(device)
                label = label.to(device)
                output = teacher_model(data)
                loss = loss_func(output, label)
                optimizer_teacher.zero_grad()
                loss.backward()
                optimizer_teacher.step()
            print("loss: ", loss)  # loss of the last batch in this epoch
            # eval
            correct = 0
            teacher_model.eval()
            with torch.no_grad():
                for test_data, test_label in test_loader:
                    test_data = test_data.to(device)
                    test_label = test_label.to(device)
                    output = teacher_model(test_data)
                    _, pred = torch.max(output, dim=1)
                    correct += float(torch.sum(pred == test_label))
            print('test_acc:{}'.format(correct / len(test_loader.dataset)))
        return teacher_model

Training results (I only trained for 5 epochs):


  
    teacher model train
    loss: tensor(0.0891, device='cuda:0', grad_fn=<NllLossBackward>)
    test_acc:0.9845
    loss: tensor(0.0132, device='cuda:0', grad_fn=<NllLossBackward>)
    test_acc:0.9865
    loss: tensor(0.0019, device='cuda:0', grad_fn=<NllLossBackward>)
    test_acc:0.9909
    loss: tensor(0.0042, device='cuda:0', grad_fn=<NllLossBackward>)
    test_acc:0.9909
    loss: tensor(0.0034, device='cuda:0', grad_fn=<NllLossBackward>)
    test_acc:0.9917
    teacher model trained finished!

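Student training without KD

Parameters:

student_model: the network chosen as the student

train_loader: the training set loader

test_loader: the test set loader

loss_func: the loss function

epochs: the number of training epochs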


  
    def student_train(student_model, train_loader, test_loader, loss_func, epochs):
        # `device` and `optimizer_student` are module-level globals (see the complete code)
        student_model.to(device)
        for i in range(epochs):
            # train: re-enter train mode every epoch, because eval() is called below
            student_model.train()
            for data, label in train_loader:
                data = data.to(device)
                label = label.to(device)
                output = student_model(data)
                loss = loss_func(output, label)
                optimizer_student.zero_grad()
                loss.backward()
                optimizer_student.step()
            print("student loss: ", loss)  # loss of the last batch in this epoch
            # eval
            correct = 0
            student_model.eval()
            with torch.no_grad():
                for test_data, test_label in test_loader:
                    test_data = test_data.to(device)
                    test_label = test_label.to(device)
                    output = student_model(test_data)
                    _, pred = torch.max(output, dim=1)
                    correct += float(torch.sum(pred == test_label))
            print('student test_acc:{}'.format(correct / len(test_loader.dataset)))

Results without KD training:


  
    student model ready train
    student loss: tensor(0.1876, device='cuda:0', grad_fn=<NllLossBackward>)
    student test_acc:0.9588
    student loss: tensor(0.0219, device='cuda:0', grad_fn=<NllLossBackward>)
    student test_acc:0.9737
    student loss: tensor(0.0588, device='cuda:0', grad_fn=<NllLossBackward>)
    student test_acc:0.9812
    student loss: tensor(0.0024, device='cuda:0', grad_fn=<NllLossBackward>)
    student test_acc:0.9853
    student loss: tensor(0.0022, device='cuda:0', grad_fn=<NllLossBackward>)
    student test_acc:0.9814
    student model trained finished!
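The teacher and student training functions repeat the same evaluation loop. One way to avoid the duplication is a small shared helper, sketched below. The name `evaluate` and its signature are hypothetical, not part of the original code, and the tiny linear model is made up purely to exercise the helper.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, device):
    """Return the top-1 accuracy of `model` over every sample in `loader`."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, labels in loader:
            data, labels = data.to(device), labels.to(device)
            pred = model(data).argmax(dim=1)
            correct += (pred == labels).sum().item()
    # Divide by the dataset size, not the number of batches.
    return correct / len(loader.dataset)


# Toy check: a linear layer rigged to always predict class 0.
toy_model = nn.Linear(3, 2)
with torch.no_grad():
    toy_model.weight.zero_()
    toy_model.bias.copy_(torch.tensor([1.0, 0.0]))
toy_data = TensorDataset(torch.randn(4, 3), torch.tensor([0, 0, 1, 1]))
print(evaluate(toy_model, DataLoader(toy_data, batch_size=2), torch.device("cpu")))
```

Because the helper takes `device` as an argument and reads the dataset size from the loader, it does not rely on any globals.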

KD training code

Parameters:

teacher_model: the teacher, already trained

student_model: the student network to be distilled

train_loader: the training set loader

test_loader: the test set loader


  
    def KD_train(teacher_model, student_model, train_loader, test_loader, loss_func, epochs):
        # `device`, `alpha` and `optimizer_student` are module-level globals;
        # `loss_func` is unused here because the hard loss is created below
        teacher_model.eval()
        student_model.to(device)
        HL = nn.CrossEntropyLoss()
        for i in range(epochs):
            # re-enter train mode every epoch, because eval() is called below
            student_model.train()
            for data, labels in train_loader:
                data = data.to(device)
                labels = labels.to(device)
                with torch.no_grad():  # the teacher is frozen, no gradients needed
                    teacher_output = teacher_model(data)
                student_output = student_model(data)
                soft_loss = KD_loss(teacher_output, student_output)
                hard_loss = HL(student_output, labels)
                loss = hard_loss + alpha * soft_loss
                optimizer_student.zero_grad()
                loss.backward()
                optimizer_student.step()
            print("KD loss: ", loss)  # loss of the last batch in this epoch
            student_model.eval()
            ACC = 0
            with torch.no_grad():
                for data, labels in test_loader:
                    data = data.to(device)
                    labels = labels.to(device)
                    output = student_model(data)
                    _, pred = torch.max(output, dim=1)
                    ACC += float(torch.sum(pred == labels))
            print('KD test_acc:{}'.format(ACC / len(test_loader.dataset)))

In the code, teacher_output is the output of the teacher network and student_output is the output of the student. The KD_loss defined between the two is as follows:

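KD_loss code:

Temp is the temperature, 2 by default (try other values on your own dataset).

alpha is the balance coefficient between the hard and soft losses (0.5 by default; again, tune it for your own case).

The loss function used is KL divergence; you can also switch to cross-entropy.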
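On the remark about swapping KL divergence for cross-entropy: with a fixed teacher, the two differ only by the teacher's entropy, which does not depend on the student, so they should give the student the same gradients. The sketch below checks this numerically. The helper names `soft_kl` and `soft_ce` and the random logits are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F


def soft_kl(student_logits, teacher_logits, T=2.0):
    """KL(teacher || student) on temperature-softened distributions."""
    log_ps = F.log_softmax(student_logits / T, dim=1)
    pt = F.softmax(teacher_logits / T, dim=1)
    return F.kl_div(log_ps, pt, reduction="batchmean")


def soft_ce(student_logits, teacher_logits, T=2.0):
    """Cross-entropy of the student against the teacher's soft targets."""
    log_ps = F.log_softmax(student_logits / T, dim=1)
    pt = F.softmax(teacher_logits / T, dim=1)
    return -(pt * log_ps).sum(dim=1).mean()


torch.manual_seed(0)
teacher_logits = torch.randn(8, 10)              # made-up teacher outputs
s1 = torch.randn(8, 10, requires_grad=True)      # made-up student outputs
s2 = s1.detach().clone().requires_grad_(True)    # identical copy for the second loss

kl = soft_kl(s1, teacher_logits)
ce = soft_ce(s2, teacher_logits)
kl.backward()
ce.backward()

print("KL:", kl.item(), "CE:", ce.item())        # the values differ by the teacher's entropy
print("same gradients:", torch.allclose(s1.grad, s2.grad, atol=1e-6))
```

The loss values themselves differ, so logged numbers are not comparable across the two choices even though training behaves the same.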


  
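    Temp = 2.    # temperature
    alpha = 0.5  # weight of the soft loss relative to the hard loss

    def KD_loss(p, q):
        # p: teacher logits, q: student logits (softmax is applied here, not by the caller)
        pt = F.softmax(p / Temp, dim=1)
        ps = F.log_softmax(q / Temp, dim=1)
        # reduction='mean' averages over every element; PyTorch's docs recommend
        # 'batchmean', which matches the mathematical definition of KL divergence
        return nn.KLDivLoss(reduction='mean')(ps, pt) * (Temp ** 2)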
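Two properties of this loss are easy to check by hand. First, dividing the logits by a larger temperature flattens the softmax, so the teacher's smaller class probabilities carry more weight. Second, `reduction='mean'` divides by batch size times number of classes, while `'batchmean'` divides by batch size only, so the two differ by a factor equal to the number of classes. The sketch below uses made-up logits, and the helper name `softened` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def softened(logits, T):
    """Softmax of logits divided by temperature T."""
    return F.softmax(logits / T, dim=1)


# 1) Temperature: higher T gives a flatter distribution.
logits = torch.tensor([[8.0, 2.0, 0.0, -1.0]])  # made-up logits for one sample
for T in (1.0, 2.0, 5.0):
    print(f"T={T}: max prob = {softened(logits, T).max().item():.3f}")

# 2) Reduction: 'mean' vs 'batchmean' on random teacher/student logits.
torch.manual_seed(0)
num_classes = 10
teacher = torch.randn(16, num_classes)
student = torch.randn(16, num_classes)
pt = softened(teacher, 2.0)
log_ps = F.log_softmax(student / 2.0, dim=1)
loss_mean = nn.KLDivLoss(reduction="mean")(log_ps, pt)          # PyTorch warns about this mode
loss_batchmean = nn.KLDivLoss(reduction="batchmean")(log_ps, pt)
print("ratio batchmean/mean:", (loss_batchmean / loss_mean).item())
```

In practice this means switching the reduction changes the effective weight of the soft loss, so alpha would likely need retuning.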

Student results after KD training:


  
    KD loss: tensor(0.2580, device='cuda:0', grad_fn=<AddBackward0>)
    KD test_acc:0.9753
    KD loss: tensor(0.1686, device='cuda:0', grad_fn=<AddBackward0>)
    KD test_acc:0.9748
    KD loss: tensor(0.0827, device='cuda:0', grad_fn=<AddBackward0>)
    KD test_acc:0.9849
    KD loss: tensor(0.0098, device='cuda:0', grad_fn=<AddBackward0>)
    KD test_acc:0.9865
    KD loss: tensor(0.0114, device='cuda:0', grad_fn=<AddBackward0>)
    KD test_acc:0.988

 


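After KD training the student improves slightly: its test accuracy after 5 epochs is 0.9880, versus 0.9814 without KD. The gain is small mainly because handwritten digits are too easy to train on, and a little training already gives high accuracy. On another dataset, such as cats vs. dogs, the difference may be more noticeable; you can try it yourself.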

To use different teacher and student networks, just replace teacher_model and student_model in the code.

Complete code

KD for object detection is more involved; that will be covered later.


  
    import torch
    from torch.optim import Adam, SGD
    import torch.nn.functional as F
    import torch.nn as nn
    from torchvision.models import resnet50, resnet34, resnet18, MobileNetV2
    import torchvision
    import torchvision.transforms as transforms

    Temp = 2.    # temperature
    alpha = 0.5  # weight of the soft loss relative to the hard loss


    def KD_loss(p, q):
        # p: teacher logits, q: student logits (softmax is applied here, not by the caller)
        pt = F.softmax(p / Temp, dim=1)
        ps = F.log_softmax(q / Temp, dim=1)
        # reduction='mean' averages over every element; PyTorch's docs recommend
        # 'batchmean', which matches the mathematical definition of KL divergence
        return nn.KLDivLoss(reduction='mean')(ps, pt) * (Temp ** 2)


    def teacher_train(teacher_model, train_loader, test_loader, loss_func, epochs):
        # `device` and `optimizer_teacher` are globals defined in the main block
        teacher_model.to(device)
        for i in range(epochs):
            # train: re-enter train mode every epoch, because eval() is called below
            teacher_model.train()
            for data, label in train_loader:
                data = data.to(device)
                label = label.to(device)
                output = teacher_model(data)
                loss = loss_func(output, label)
                optimizer_teacher.zero_grad()
                loss.backward()
                optimizer_teacher.step()
            print("loss: ", loss)  # loss of the last batch in this epoch
            # eval
            correct = 0
            teacher_model.eval()
            with torch.no_grad():
                for test_data, test_label in test_loader:
                    test_data = test_data.to(device)
                    test_label = test_label.to(device)
                    output = teacher_model(test_data)
                    _, pred = torch.max(output, dim=1)
                    correct += float(torch.sum(pred == test_label))
            print('test_acc:{}'.format(correct / len(test_loader.dataset)))
        return teacher_model


    def student_train(student_model, train_loader, test_loader, loss_func, epochs):
        # `device` and `optimizer_student` are globals defined in the main block
        student_model.to(device)
        for i in range(epochs):
            # train: re-enter train mode every epoch, because eval() is called below
            student_model.train()
            for data, label in train_loader:
                data = data.to(device)
                label = label.to(device)
                output = student_model(data)
                loss = loss_func(output, label)
                optimizer_student.zero_grad()
                loss.backward()
                optimizer_student.step()
            print("student loss: ", loss)  # loss of the last batch in this epoch
            # eval
            correct = 0
            student_model.eval()
            with torch.no_grad():
                for test_data, test_label in test_loader:
                    test_data = test_data.to(device)
                    test_label = test_label.to(device)
                    output = student_model(test_data)
                    _, pred = torch.max(output, dim=1)
                    correct += float(torch.sum(pred == test_label))
            print('student test_acc:{}'.format(correct / len(test_loader.dataset)))


    def KD_train(teacher_model, student_model, train_loader, test_loader, loss_func, epochs):
        # `loss_func` is unused here because the hard loss is created below
        teacher_model.eval()
        student_model.to(device)
        HL = nn.CrossEntropyLoss()
        for i in range(epochs):
            # re-enter train mode every epoch, because eval() is called below
            student_model.train()
            for data, labels in train_loader:
                data = data.to(device)
                labels = labels.to(device)
                with torch.no_grad():  # the teacher is frozen, no gradients needed
                    teacher_output = teacher_model(data)
                student_output = student_model(data)
                soft_loss = KD_loss(teacher_output, student_output)
                hard_loss = HL(student_output, labels)
                loss = hard_loss + alpha * soft_loss
                optimizer_student.zero_grad()
                loss.backward()
                optimizer_student.step()
            print("KD loss: ", loss)  # loss of the last batch in this epoch
            student_model.eval()
            ACC = 0
            with torch.no_grad():
                for data, labels in test_loader:
                    data = data.to(device)
                    labels = labels.to(device)
                    output = student_model(data)
                    _, pred = torch.max(output, dim=1)
                    ACC += float(torch.sum(pred == labels))
            print('KD test_acc:{}'.format(ACC / len(test_loader.dataset)))


    def do_train(teacher_model, student_model, train_loader, test_loader, loss_func, epochs):
        # train the teacher
        teacher_model.to(device)
        print("teacher model train")
        Teacher = teacher_train(teacher_model, train_loader, test_loader, loss_func, epochs)
        print("teacher model trained finished!")
        # Uncomment to train the student without KD as a baseline:
        # print("\n student model ready train")
        # student_train(student_model, train_loader, test_loader, loss_func, epochs)
        # print("\n student model trained finished!")
        print("\n KD model ready train")
        KD_train(Teacher, student_model, train_loader, test_loader, loss_func, epochs)


    if __name__ == "__main__":
        # prepare the dataset
        batch_size = 64
        transform = transforms.Compose([
            transforms.Resize(28),  # MNIST is already 28x28 (ResNet's usual input is 224x224)
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # 1 channel -> 3 channels for ResNet
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])
        train_dataset = torchvision.datasets.MNIST('./data/', train=True, download=True,
                                                   transform=transform)
        test_dataset = torchvision.datasets.MNIST('./data/', train=False, download=True,
                                                  transform=transform)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        sample, label = next(iter(train_loader))
        print(sample.shape)
        print("current labels: ", label)

        num_classes = 10
        lr = 0.01
        epochs = 5
        # fall back to CPU when no GPU is available
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        teacher_model = resnet18(num_classes=num_classes)
        student_model = resnet50(num_classes=num_classes)
        optimizer_teacher = SGD(teacher_model.parameters(), lr=lr, momentum=0.9)
        optimizer_student = SGD(student_model.parameters(), lr=lr, momentum=0.9)
        loss_function = nn.CrossEntropyLoss()
        do_train(teacher_model, student_model, train_loader, test_loader, loss_function, epochs)

 


Reposted from: https://blog.csdn.net/z240626191s/article/details/128154154