
[PointNet] MATLAB simulation of 3D point cloud object classification based on PointNet


1. Software version

MATLAB R2021a

2. System overview

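The PointNet network structure used here is shown in the figure below.

The network first performs set abstraction (the building block of PointNet++). This step partitions the point cloud into local regions and extracts a feature for each region. As the figure shows, a set abstraction level has three layers: a sampling layer, a grouping layer and a PointNet layer.

- The sampling layer selects the region centroids using farthest point sampling (FPS).
- The grouping layer gathers the neighbors of each centroid, using either multi-resolution grouping (MRG) or multi-scale grouping (MSG).
- The PointNet layer extracts features from each group of points.

With MSG, the first set abstraction level takes 512 centroids and groups points at three radii: 0.1, 0.2 and 0.4. Each ball keeps at most 16, 32 and 128 points, respectively.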
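The MSG numbers above can be made concrete with a short Python sketch. It uses plain standard library Python and is not part of the original MATLAB program. The sketch computes the shape of the grouped tensor at each scale of the first set abstraction level. It assumes that every ball is padded or truncated to exactly its maximum point count, and that each point carries only its xyz coordinates. The names `MSG_LEVEL1` and `grouped_shapes` are hypothetical.

```python
# Hypothetical MSG configuration of the first set abstraction level,
# using the numbers given in the text.
MSG_LEVEL1 = {
    "num_centroids": 512,
    "radii": [0.1, 0.2, 0.4],
    "max_points": [16, 32, 128],
}


def grouped_shapes(cfg, channels=3):
    """Return one (num_centroids, max_points, channels) shape per scale.

    Assumes each ball is padded/truncated to exactly max_points points and
    that each point carries only its xyz coordinates (channels=3).
    """
    return [
        (cfg["num_centroids"], k, channels)
        for _radius, k in zip(cfg["radii"], cfg["max_points"])
    ]


for radius, shape in zip(MSG_LEVEL1["radii"], grouped_shapes(MSG_LEVEL1)):
    print(f"radius {radius}: grouped tensor {shape}")
```

In PointNet++ MSG, each scale is processed by its own PointNet. The per-scale features of a centroid are then concatenated into one multi-scale feature.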

Sampling layer

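The sampling layer selects a subset of the input points, and these points define the centroids of the local regions. It uses iterative farthest point sampling (FPS). FPS starts from a randomly chosen point. It then repeatedly adds the point whose distance to the already selected set is largest, until the required number of points has been chosen. Compared with random sampling, FPS covers the whole point cloud more evenly for the same number of centroids.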
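The FPS procedure above can be sketched in a few lines of standard library Python (3.8+ for `math.dist`). This is a minimal illustration, not the implementation used in the MATLAB program. The function name `farthest_point_sampling` is hypothetical. The key detail is the array `min_dist`, which stores each point's distance to its *nearest* selected point. The next sample is the point that maximizes this distance.

```python
import math
import random


def farthest_point_sampling(points, num_samples, seed=None):
    """Iterative farthest point sampling; returns indices into points.

    points: list of (x, y, z) tuples.
    Each step adds the point whose distance to the already selected set
    (i.e. to its nearest selected point) is largest.
    """
    if not 0 < num_samples <= len(points):
        raise ValueError("num_samples must be in 1..len(points)")
    rng = random.Random(seed)
    selected = [rng.randrange(len(points))]  # random initial point
    # min_dist[i]: distance from point i to its nearest selected point
    min_dist = [math.dist(p, points[selected[0]]) for p in points]
    while len(selected) < num_samples:
        next_idx = max(range(len(points)), key=min_dist.__getitem__)
        selected.append(next_idx)
        # Refresh nearest-selected distances with the newly added point
        for i, p in enumerate(points):
            d = math.dist(p, points[next_idx])
            if d < min_dist[i]:
                min_dist[i] = d
    return selected
```

Each iteration scans all N points, so sampling M centroids costs O(N·M) distance evaluations.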

Grouping layer

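The grouping layer builds the local regions from which features are then extracted. The idea is to group each centroid with its neighboring points. The paper uses a neighborhood ball (ball query) rather than KNN, because a ball guarantees a fixed region scale: membership is decided by distance alone.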
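A ball query can be sketched as follows in standard library Python. This is an illustration, not the program's code, and the name `ball_query` is hypothetical. The padding rule is an assumption: a ball with fewer than `max_points` members is filled by repeating the first point found, a common convention that gives every group the same size. A ball with more members is truncated.

```python
import math


def ball_query(points, centroid_indices, radius, max_points):
    """Group each centroid with the points lying within `radius` of it.

    Returns exactly max_points indices per centroid: the first max_points
    points inside the ball (in index order), padded by repeating the first
    one found when the ball holds fewer points. The centroid is always
    inside its own ball, so no group is empty.
    """
    groups = []
    for c in centroid_indices:
        center = points[c]
        inside = [i for i, p in enumerate(points)
                  if math.dist(p, center) <= radius][:max_points]
        inside += [inside[0]] * (max_points - len(inside))  # pad to fixed size
        groups.append(inside)
    return groups
```

Unlike KNN, which always returns k neighbors however far away they are, the ball never reaches beyond `radius`. In sparse areas a group therefore contains repeated points instead of distant ones.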

PointNet layer

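For extracting local features from a point cloud, the original PointNet already works well. PointNet++ therefore reuses the original PointNet as a subnetwork and applies it hierarchically, extracting features level by level.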
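The core of a PointNet layer is a shared MLP followed by a symmetric max pooling. The sketch below is a minimal standard library Python illustration. It uses a single fully connected layer with ReLU where a real network stacks several, and `shared_mlp` and `pointnet_layer` are hypothetical names. The same weights are applied to every point. Max pooling over the points then makes the output independent of the order in which the points are given.

```python
def shared_mlp(points, weights, biases):
    """Apply the same fully connected layer + ReLU to every point.

    weights: out_dim rows, each of length in_dim; biases: length out_dim.
    """
    return [
        [max(0.0, sum(w * x for w, x in zip(row, p)) + b)
         for row, b in zip(weights, biases)]
        for p in points
    ]


def pointnet_layer(points, weights, biases):
    """Shared per-point MLP, then channel-wise max pooling over the points."""
    features = shared_mlp(points, weights, biases)
    return [max(channel) for channel in zip(*features)]


W = [[1, 0, 0], [0, 1, 0], [-1, -1, -1]]
b = [0, 0, 0]
group = [(0.1, 0.2, 0.3), (0.5, -0.1, 0.0), (-0.2, 0.4, 0.1)]
print(pointnet_layer(group, W, b))  # [0.5, 0.4, 0.0]
```

Reversing or shuffling `group` leaves the printed vector unchanged. This permutation invariance is why PointNet can consume unordered point sets directly.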

3. Core code (partial)


  
    clc;
    clear;
    close all;
    warning off;
    addpath(genpath(pwd));
    rng('default')
    % Load the training and test sets (PtCloudClassificationDatastore is a
    % custom datastore that is not shown here)
    dsTrain = PtCloudClassificationDatastore('train');
    dsVal = PtCloudClassificationDatastore('test');
    % Display a sample point cloud
    ptCloud = pcread('Chair.ply');
    label = 'Chair';
    figure; pcshow(ptCloud)
    xlabel("X"); ylabel("Y"); zlabel("Z"); title(label)
    % Collect the label and the number of points of every training cloud
    dsLabelCounts = transform(dsTrain, @(data){data{2} data{1}.Count});
    labelCounts = readall(dsLabelCounts);
    labels = vertcat(labelCounts{:,1});
    counts = vertcat(labelCounts{:,2});
    figure; histogram(labels); title('class distribution')
    rng(0)
    % Oversample the infrequent classes so that every class has as many
    % training files as the largest class
    [G, classes] = findgroups(labels);
    numObservations = splitapply(@numel, labels, G);
    desiredNumObservationsPerClass = max(numObservations);
    filesOverSample = [];
    for i = 1:numel(classes)
        % Files of class i form a contiguous block in dsTrain.Files
        % (assumes the files are ordered by class, in the order of classes).
        % The block starts after all files of classes 1..i-1.
        startIdx = sum(numObservations(1:i-1)) + 1;
        endIdx = sum(numObservations(1:i));
        targetFiles = {dsTrain.Files{startIdx:endIdx}};
        % Randomly replicate the point clouds belonging to the infrequent classes
        files = randReplicateFiles(targetFiles, desiredNumObservationsPerClass);
        filesOverSample = vertcat(filesOverSample, files');
    end
    dsTrain.Files = filesOverSample;
    dsTrain.Files = dsTrain.Files(randperm(length(dsTrain.Files)));
    dsTrain.MiniBatchSize = 32;
    dsVal.MiniBatchSize = dsTrain.MiniBatchSize;
    % Data augmentation (augmentPointCloud is a helper not shown here)
    dsTrain = transform(dsTrain, @augmentPointCloud);
    data = preview(dsTrain);
    ptCloud = data{1,1};
    label = data{1,2};
    figure; pcshow(ptCloud.Location, [0 0 1], "MarkerSize", 40, "VerticalAxisDir", "down")
    xlabel("X"); ylabel("Y"); zlabel("Z"); title(label)
    % Per-class statistics of the number of points per cloud
    minPointCount = splitapply(@min, counts, G);
    maxPointCount = splitapply(@max, counts, G);
    meanPointCount = splitapply(@(x) round(mean(x)), counts, G);
    stats = table(classes, numObservations, minPointCount, maxPointCount, meanPointCount)
    % Use a fixed number of points per cloud, then preprocess
    % (selectPoints and preprocessPointCloud are helpers not shown here)
    numPoints = 1000;
    dsTrain = transform(dsTrain, @(data) selectPoints(data, numPoints));
    dsVal = transform(dsVal, @(data) selectPoints(data, numPoints));
    dsTrain = transform(dsTrain, @preprocessPointCloud);
    dsVal = transform(dsVal, @preprocessPointCloud);
    data = preview(dsTrain);
    figure; pcshow(data{1,1}, [0 0 1], "MarkerSize", 40, "VerticalAxisDir", "down");
    xlabel("X"); ylabel("Y"); zlabel("Z"); title(data{1,2})
    % Input transform network on the raw xyz coordinates
    inputChannelSize = 3;
    hiddenChannelSize1 = [64, 128];
    hiddenChannelSize2 = 256;
    [parameters.InputTransform, state.InputTransform] = initializeTransform(inputChannelSize, hiddenChannelSize1, hiddenChannelSize2);
    % First shared MLP (hidden sizes 64 and 64)
    inputChannelSize = 3;
    hiddenChannelSize = [64 64];
    [parameters.SharedMLP1, state.SharedMLP1] = initializeSharedMLP(inputChannelSize, hiddenChannelSize);
    % Feature transform network on the 64-dimensional point features
    inputChannelSize = 64;
    hiddenChannelSize1 = [64, 128];
    hiddenChannelSize2 = 256;
    [parameters.FeatureTransform, state.FeatureTransform] = initializeTransform(inputChannelSize, hiddenChannelSize1, hiddenChannelSize2);
    % Second shared MLP (hidden size 64)
    inputChannelSize = 64;
    hiddenChannelSize = 64;
    [parameters.SharedMLP2, state.SharedMLP2] = initializeSharedMLP(inputChannelSize, hiddenChannelSize);
    % Classification MLP (hidden sizes 512 and 256, numClasses outputs)
    inputChannelSize = 64;
    hiddenChannelSize = [512, 256];
    numClasses = numel(classes);
    [parameters.ClassificationMLP, state.ClassificationMLP] = initializeClassificationMLP(inputChannelSize, hiddenChannelSize, numClasses);
    % Training options
    numEpochs = 60;
    learnRate = 0.001;
    l2Regularization = 0.1;
    learnRateDropPeriod = 15;
    learnRateDropFactor = 0.5;
    gradientDecayFactor = 0.9;
    squaredGradientDecayFactor = 0.999;
    % Adam moment estimates (empty before the first update)
    avgGradients = [];
    avgSquaredGradients = [];
    [lossPlotter, trainAccPlotter, valAccPlotter] = initializeTrainingProgressPlot;
    % Initialize the iteration counter
    iteration = 0;
    % Start the timer that measures the training time
    start = tic;
    % Loop over the epochs
    for epoch = 1:numEpochs
        % Reset the training and validation datastores
        reset(dsTrain);
        reset(dsVal);
        % Iterate through the training data; the loop ends when no data is left
        while hasdata(dsTrain)
            iteration = iteration + 1;
            % Read a mini-batch and create the batch arrays
            data = read(dsTrain);
            [XTrain, YTrain] = batchData(data, classes);
            % Evaluate the model gradients and loss using dlfeval and the
            % modelGradients function
            [gradients, loss, state, acc] = dlfeval(@modelGradients, XTrain, YTrain, parameters, state);
            % L2 regularization: add l2Regularization*p to the gradient of each parameter p
            gradients = dlupdate(@(g,p) g + l2Regularization*p, gradients, parameters);
            % Update the network parameters using the Adam optimizer
            [parameters, avgGradients, avgSquaredGradients] = adamupdate(parameters, gradients, ...
                avgGradients, avgSquaredGradients, iteration, learnRate, gradientDecayFactor, squaredGradientDecayFactor);
            % Update the training progress plot
            D = duration(0, 0, toc(start), "Format", "hh:mm:ss");
            title(lossPlotter.Parent, "Epoch: " + epoch + ", Elapsed: " + string(D))
            addpoints(lossPlotter, iteration, double(gather(extractdata(loss))))
            addpoints(trainAccPlotter, iteration, acc);
            drawnow
        end
        % Create the confusion matrix for this epoch's validation pass
        cmat = sparse(numClasses, numClasses);
        % Classify the validation data to monitor the training process
        isTrainingVal = 0;   % zero for validation data (inference mode)
        while hasdata(dsVal)
            data = read(dsVal);                        % get the next batch of data
            [XVal, YVal] = batchData(data, classes);   % create batch
            % Compute label predictions
            YPred = pointnetClassifier(XVal, parameters, state, isTrainingVal);
            % Choose the class with the highest score as the label for XVal
            [~, YValLabel] = max(YVal, [], 1);
            [~, YPredLabel] = max(YPred, [], 1);
            cmat = aggreateConfusionMetric(cmat, YValLabel, YPredLabel);   % update the confusion matrix
        end
        % Update the progress plot with the overall validation accuracy
        acc = sum(diag(cmat)) ./ sum(cmat, "all");
        addpoints(valAccPlotter, iteration, acc);
        % Step learning-rate decay
        if mod(epoch, learnRateDropPeriod) == 0
            learnRate = learnRate * learnRateDropFactor;
        end
        reset(dsTrain);   % reset the training data, since all of it has been read
        % Shuffle the training files at every epoch
        dsTrain.UnderlyingDatastore.Files = dsTrain.UnderlyingDatastore.Files(randperm(length(dsTrain.UnderlyingDatastore.Files)));
        reset(dsVal);
    end
    % Final evaluation on the whole validation set
    cmat = sparse(numClasses, numClasses);   % empty numClasses-by-numClasses confusion matrix
    reset(dsVal);                             % reset the validation data
    data = readall(dsVal);                    % read all validation data
    [XVal, YVal] = batchData(data, classes);  % create one batch
    % Classify the validation data using the helper function pointnetClassifier
    isTrainingVal = 0;
    YPred = pointnetClassifier(XVal, parameters, state, isTrainingVal);
    % Choose the class with the highest score as the label for XVal
    [~, YValLabel] = max(YVal, [], 1);
    [~, YPredLabel] = max(YPred, [], 1);
    % Collect confusion metrics
    cmat = aggreateConfusionMetric(cmat, YValLabel, YPredLabel);
    figure; chart = confusionchart(cmat, classes);
    acc = sum(diag(cmat)) ./ sum(cmat, "all")
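The class-balancing loop relies on each class's files forming one contiguous block. Class i therefore starts right after all files of classes 1..i-1. Take class counts [3, 1, 2] as an example. The expression `numObservations(i-1)+1` would start the third class at file 2, but its block really starts at file 5. The Python sketch below shows the cumulative-sum offsets with 0-based slices. It is only an illustration. `rand_replicate` is a hypothetical stand-in for the `randReplicateFiles` helper, whose code is not shown. The assumption is that it draws the desired number of files uniformly at random, with replacement.

```python
import random
from itertools import accumulate


def class_slices(num_observations):
    """0-based (start, end) slice of every class block in a class-sorted list."""
    ends = list(accumulate(num_observations))
    starts = [0] + ends[:-1]
    return list(zip(starts, ends))


def rand_replicate(files, num_desired, rng):
    """Hypothetical stand-in for randReplicateFiles: sample with replacement."""
    return [rng.choice(files) for _ in range(num_desired)]


def oversample(files, num_observations, seed=0):
    """Give every class as many files as the largest class, then shuffle."""
    rng = random.Random(seed)
    desired = max(num_observations)
    out = []
    for start, end in class_slices(num_observations):
        out.extend(rand_replicate(files[start:end], desired, rng))
    rng.shuffle(out)
    return out


files = ["a1", "a2", "a3", "b1", "c1", "c2"]
print(class_slices([3, 1, 2]))  # [(0, 3), (3, 4), (4, 6)]
```

After oversampling, each class contributes exactly `max(num_observations)` entries, drawn only from its own block.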

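The accuracy printed at the end is the trace of the confusion matrix divided by its total count. The Python sketch below illustrates that bookkeeping. `aggregate_confusion` is a hypothetical stand-in for the helper `aggreateConfusionMetric`, whose code is not shown. The sketch assumes rows are true classes, columns are predicted classes, and labels are 0-based.

```python
def aggregate_confusion(cmat, true_labels, pred_labels):
    """Add one count per sample at (true class, predicted class)."""
    for t, p in zip(true_labels, pred_labels):
        cmat[t][p] += 1
    return cmat


def accuracy(cmat):
    """Overall accuracy: trace / total, like sum(diag(cmat))./sum(cmat,"all")."""
    total = sum(map(sum, cmat))
    return sum(cmat[i][i] for i in range(len(cmat))) / total


cmat = [[0] * 3 for _ in range(3)]
aggregate_confusion(cmat, [0, 1, 2, 2], [0, 1, 1, 2])  # first batch
aggregate_confusion(cmat, [0, 0], [0, 2])              # second batch
print(accuracy(cmat))  # 4 correct out of 6
```

This overall accuracy weights every sample equally. Because only the training set is oversampled, large validation classes dominate it. The per-class rows of the confusion chart show how the small classes fare.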
4. Simulation results


5. References

[1] Qi C R, Su H, Mo K, et al. PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation[C]// 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2017.


Reposted from: https://blog.csdn.net/ccsss22/article/details/125434324