飞道的博客

(学生快速上手向)python图片分类识别器

299人阅读  评论(0)

本文着重讲不学无术的大学生如何快速上手跑出结果。本项目基于resnet34识别四类示意图,由cat vs dog项目改写而来。文末会说明如何快速把它改成你想要的项目(图片二分类等)。


项目代码、数据集下载:ht删tps://p除an.bai中du.c文om/s/1F打aI6hKNPB_0w_oed9H开0STg 提取码: z5v5


1.各文件/文件夹作用

 

自上到下:

checkpoints  储存每个epoch训练后的模型

datasets  储存训练集、测试集

image  用来给数据集做重命名,后面会提到

result  似乎没用过?

图片分类结果  手动分类的数据集。将示意图分四类,每类约150张

config  储存模型相关参数完全不用修改

dataset  数据集预处理等工作。

rename  数据集图片重命名用,后面会讲

test_model是从checkpoints里取出来训练好的模型改个名,文件夹里是我们的模型

test  测试程序,train  训练程序。


2.如何运行项目

先自己看import哪些库,装好库

①图片重命名

我使用的数据集存在图片分类结果文件夹了,你也可以不用它。

把分类好的四类图片中任一类(如sketch1)全部放入image/raw。

将rename.py中的label = 'sketch4'改成label = 'sketch1'

index_list = [i for i in range(52, imgs_num + 52)]也要根据图片数量做调整相信废物大学生也能看得懂

运行rename.py会在image/processed生成重命名好的图片。格式为sktech1.0.jpg、sktech1.1.jpg、sktech1.2.jpg等。将这些图片二八分开分别放入datasets/test和datasets/train

四类图片都要这样处理。

需要注意的是,最后无论是test文件夹还是train文件夹,图片的id不能重复,比如sktech1.0.jpg里0就是id。不能同时存在sktech1.0.jpg​​​​​​​和sktech2.0.jpg 。

②运行train.py训练模型。

此时checkpoints文件夹里会多出来很多模型,同时shell会输出正确率。当你认为正确率够高就可以停了,从checkpoints拿出最新的模型改名为test_model,拿到主目录替换我们的模型。

③运行test.py输出正确率。

此时项目运行完成。


3.Q&A

老师的要求是分类其他类型的图片,不是你给的示意图。怎么办?

答:用你自己的数据集即可。不知道怎么找数据集可以评论区问。

老师的要求是图片的二/三分类,怎么修改代码?

答:以二分类为例。修改以下代码:

datasets.py:第60行

 从四类改两类。

rename.py:重命名图片跟着上面步骤做。

test_modification.py:

29行的model.fc = nn.Linear(512, 4)   把4改成2.

48行(下图)改2类

72行同理:

train.py:

 30行model.fc = nn.Linear(512,4)   把4改成2

110行confusion_matrix = meter.ConfusionMeter(4)  把4改2

120行accuracy = 100.* (cm_value[0][0] + cm_value[1][1] + cm_value[2][2] + cm_value[3][3]) / (cm_value.sum())    把cm_value[2][2] + cm_value[3][3])删掉,只留两类。

应该就这些,改不好来评论区问。

③你这项目没做可视化啊?

答:确实。


本文结束


以下代码无关本文,仅充数用


  
  1. # coding=utf-8
  2. """ test
  3. 使用测试集测试模型结果
  4. """
  5. from config import _setting_
  6. import os
  7. import torch as t
  8. from dataset import NatureSketchClassification
  9. from torch.utils.data import DataLoader
  10. from torchnet import meter
  11. from torch.autograd import Variable
  12. from torchvision import models
  13. from torch import nn
  14. import time
  15. import csv
  16. """"""
  17. def test( **kwargs):
  18. # set data
  19. test_data = NatureSketchClassification(_setting_.test_data_root, test= True)
  20. test_dataloader = DataLoader(test_data, batch_size=_setting_.batch_size, shuffle= False, num_workers=_setting_.num_workers)
  21. results = []
  22. # set model
  23. model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
  24. model.fc = nn.Linear( 512, 4)
  25. model.load_state_dict(t.load( './test_model.pth', map_location= 'cpu'))
  26. model. eval()
  27. for id, (data, path) in enumerate(test_dataloader):
  28. # input = Variable(data,volatile=True)
  29. with t.no_grad():
  30. input = Variable(data)
  31. score = model( input)
  32. print( 'score=',score) #检验score
  33. path = path.numpy().tolist()
  34. _,predicted = t. max(score.data, 1)
  35. #Modification
  36. predicted = predicted.data.cpu().numpy().tolist()
  37. res = ""
  38. print( 'predicted=',predicted) #检验predicted
  39. #Modification
  40. for (i, j) in zip(path, predicted):
  41. if j == 0:
  42. res = "sketch1"
  43. elif j == 1:
  44. res = "sketch2"
  45. elif j == 2:
  46. res = "sketch3"
  47. elif j == 3:
  48. res = "sketch4"
  49. print( 'res=',res) #检验res(result)
  50. results.append([i, "".join(res)])
  51. res = []
  52. truth = ""
  53. compare = ""
  54. imgs = [os.path.join(_setting_.test_data_root,img) for img in os.listdir(_setting_.test_data_root)] #获取root路径下所有图片的地址
  55. imgs_num = len(imgs) # 图片数量
  56. NumofCorrect = 0
  57. imgs = sorted(imgs,key= lambda x: int(x.split( '.')[- 2].split( '/')[- 1])) # 按序号排序
  58. for image in imgs:
  59. id = int(image.split( '.')[- 2].split( '/')[- 1]) # 获取id
  60. #Modification
  61. if 'sketch1' in image.split( '/')[- 1]:
  62. truth = 'sketch1'
  63. elif 'sketch2' in image.split( '/')[- 1]:
  64. truth = 'sketch2'
  65. elif 'sketch3' in image.split( '/')[- 1]:
  66. truth = 'sketch3'
  67. else:
  68. truth = 'sketch4'
  69. print( 'truth=',truth)
  70. #truth = 'nature' if 'nature' in image.split('/')[-1] else 'sketch' # 获取图片的真实分类
  71. compare = 'true' if truth == results[ id - 1][ 1] else 'false'
  72. if compare == 'true':
  73. NumofCorrect = NumofCorrect + 1
  74. res.append([results[ id - 1][ 0], results[ id - 1][ 1], "".join(truth), compare])
  75. Accuracy = NumofCorrect / imgs_num * 100
  76. round(Accuracy, 2)
  77. write_csv(res, _setting_.result_file, Accuracy)
  78. for id, label, truth, compare in res:
  79. if compare == 'false':
  80. print( "number: "+ str( id) + ", res: " + label + ", truth: " + truth + ", IsCorrect: " + compare)
  81. print( "Accuracy: " + str(Accuracy))
  82. return results
  83. """"""
  84. def write_csv( results, file_name, acc):
  85. Accuracy = []
  86. Accuracy.append([ " ", "Accuracy", "".join( str(acc))])
  87. with open(file_name, "w") as f:
  88. writer = csv.writer(f)
  89. writer.writerow([ 'id', 'label', 'truth', 'IsCorrect'])
  90. writer.writerows(results)
  91. writer.writerows(Accuracy)
  92. if __name__ == '__main__':
  93. test()


转载:https://blog.csdn.net/m0_66480474/article/details/127592119
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场