1. Software version
MATLAB R2021a
2. System overview
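The PointNet network structure used here is shown in the figure below:
Within the overall architecture, the first stage is set abstraction, which partitions the point cloud into local regions and extracts features from them. As the figure shows, set abstraction consists of three layers: a Sampling layer, a Grouping layer, and a PointNet layer. The sampling layer picks the centroid points using the FPS algorithm. The grouping layer builds a group around each centroid, using either the MRG (multi-resolution grouping) or the MSG (multi-scale grouping) scheme. Finally, PointNet extracts features from the grouped points. In the MSG configuration, the first set abstraction level takes 512 centroids with radii of 0.1, 0.2, and 0.4, and the maximum number of points per ball is 16, 32, and 128, respectively.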
Sampling layer
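The sampling layer selects a subset of points from the input point cloud, and these points define the centers of the local regions. Sampling uses iterative farthest point sampling (FPS): start from a randomly chosen point, then repeatedly add the point that is farthest from the set of points already selected, until the required number of points has been chosen. Compared with random sampling, the resulting centroids cover the whole point cloud more evenly.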
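The FPS procedure described above can be sketched in a few lines of MATLAB. This is a minimal illustration, not the sampling code of the network itself: the toy point cloud, the sample count, and all variable names are assumptions made for the example.

```matlab
% Minimal sketch of iterative farthest point sampling (FPS).
% The point cloud and the sample count are made-up illustration values.
rng(0);                            % fixed seed so the run is repeatable
points = rand(200,3);              % N-by-3 toy point cloud
numSamples = 16;                   % number of centroids to keep

numPts = size(points,1);
sampleIdx = zeros(numSamples,1);
minSqDist = inf(numPts,1);         % squared distance of each point to the selected set

sampleIdx(1) = randi(numPts);      % start from a random point
for k = 2:numSamples
    % Squared distance of every point to the most recently selected centroid.
    d = sum((points - points(sampleIdx(k-1),:)).^2, 2);
    % Keep, for every point, the distance to its nearest selected centroid.
    minSqDist = min(minSqDist, d);
    % The next centroid is the point farthest from the selected set.
    [~, sampleIdx(k)] = max(minSqDist);
end
centroids = points(sampleIdx,:);   % numSamples-by-3 region centers
```

Because a selected point has distance zero to the selected set, it is never picked twice as long as the cloud contains no duplicate points.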
Grouping layer
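The goal of this layer is to build the local regions from which features are then extracted. The idea is to use neighboring points. The paper uses a neighborhood ball (ball query) rather than KNN, because a ball guarantees a fixed region scale. The grouping criterion is still distance.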
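A ball query around each centroid might look like the following sketch. It is simplified under stated assumptions: the centroids are just the first four points rather than an FPS result, the radius and point cap are example values, and groups with fewer points than the cap are not padded.

```matlab
% Minimal sketch of ball-query grouping around a few centroids.
% All values and variable names are illustrative assumptions.
rng(0);
points = rand(200,3);              % N-by-3 toy point cloud
centroids = points(1:4,:);         % stand-in centroids (normally the FPS output)
radius = 0.2;                      % fixed region scale
maxPointsPerGroup = 16;            % cap on the number of points per ball

numCentroids = size(centroids,1);
groups = cell(numCentroids,1);
for c = 1:numCentroids
    % Euclidean distance from every point to centroid c.
    d = sqrt(sum((points - centroids(c,:)).^2, 2));
    idx = find(d <= radius);                         % points inside the ball
    idx = idx(1:min(numel(idx), maxPointsPerGroup)); % keep at most the cap
    % Store coordinates relative to the centroid.
    groups{c} = points(idx,:) - centroids(c,:);
end
```

Unlike KNN, the number of points per group varies with the local density, but every group spans the same physical radius.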
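PointNet layer
For local feature extraction, the original PointNet already extracts point cloud features well. In PointNet++, the original PointNet network therefore becomes a sub-network that is applied hierarchically, level by level, to extract features.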
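The core of a PointNet-style feature extractor is a shared MLP applied to every point, followed by a symmetric max pooling over the points. The sketch below illustrates that idea with random, untrained weights. The layer sizes, the weights, and the variable names are assumptions for the example, and batch normalization is left out.

```matlab
% Minimal sketch of a PointNet-style layer on one local region:
% shared MLP per point, then channel-wise max pooling.
% Weights are random and untrained; sizes are illustrative assumptions.
rng(0);
group = rand(32,3);                      % one region: 32 points with xyz coordinates
W1 = randn(3,64);   b1 = zeros(1,64);    % first shared layer
W2 = randn(64,128); b2 = zeros(1,128);   % second shared layer
relu = @(x) max(x,0);

% Shared MLP: the same weights are applied to every point (row).
pointFeatures = relu(relu(group*W1 + b1)*W2 + b2);   % 32-by-128
% Symmetric aggregation: maximum over the points for each channel.
regionFeature = max(pointFeatures,[],1);             % 1-by-128

% Reordering the points leaves the pooled feature unchanged.
shuffled = group(randperm(size(group,1)),:);
regionFeatureShuffled = max(relu(relu(shuffled*W1 + b1)*W2 + b2),[],1);
maxDifference = max(abs(regionFeature - regionFeatureShuffled));
```

The max pooling is what makes the feature independent of point order, which is why the same sub-network can be reused on every local region.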
3. Selected core code
clc;
clear;
close all;
warning off;
addpath(genpath(pwd));
rng('default')
% Load the training and validation sets.
dsTrain = PtCloudClassificationDatastore('train');
dsVal = PtCloudClassificationDatastore('test');

% Display one sample point cloud.
ptCloud = pcread('Chair.ply');
label = 'Chair';
figure;pcshow(ptCloud)
xlabel("X");ylabel("Y");zlabel("Z");title(label)

% Collect the label and the point count of every training sample.
dsLabelCounts = transform(dsTrain,@(data){data{2} data{1}.Count});
labelCounts = readall(dsLabelCounts);
labels = vertcat(labelCounts{:,1});
counts = vertcat(labelCounts{:,2});
figure;histogram(labels);title('class distribution')
% Balance the classes by oversampling each one to the size of the largest class.
rng(0)
[G,classes] = findgroups(labels);
numObservations = splitapply(@numel,labels,G);
desiredNumObservationsPerClass = max(numObservations);
filesOverSample = [];
for i = 1:numel(classes)
    % Index range of class i. This assumes dsTrain.Files is ordered class by
    % class, in the same order as classes.
    startIdx = sum(numObservations(1:i-1)) + 1;
    endIdx = sum(numObservations(1:i));
    targetFiles = {dsTrain.Files{startIdx:endIdx}};
    % Randomly replicate the point clouds belonging to the infrequent classes
    files = randReplicateFiles(targetFiles,desiredNumObservationsPerClass);
    filesOverSample = vertcat(filesOverSample,files');
end
dsTrain.Files = filesOverSample;

% Shuffle the training files.
dsTrain.Files = dsTrain.Files(randperm(length(dsTrain.Files)));
% Mini-batch size for training and validation.
dsTrain.MiniBatchSize = 32;
dsVal.MiniBatchSize = dsTrain.MiniBatchSize;

% Apply data augmentation to the training set and preview one sample.
dsTrain = transform(dsTrain,@augmentPointCloud);
data = preview(dsTrain);
ptCloud = data{1,1};
label = data{1,2};
figure;pcshow(ptCloud.Location,[0 0 1],"MarkerSize",40,"VerticalAxisDir","down")
xlabel("X");ylabel("Y");zlabel("Z");title(label)
% Per-class statistics of the number of points per cloud.
minPointCount = splitapply(@min,counts,G);
maxPointCount = splitapply(@max,counts,G);
meanPointCount = splitapply(@(x)round(mean(x)),counts,G);
stats = table(classes,numObservations,minPointCount,maxPointCount,meanPointCount)

% Select a fixed number of points from each cloud, then preprocess.
numPoints = 1000;
dsTrain = transform(dsTrain,@(data)selectPoints(data,numPoints));
dsVal = transform(dsVal,@(data)selectPoints(data,numPoints));
dsTrain = transform(dsTrain,@preprocessPointCloud);
dsVal = transform(dsVal,@preprocessPointCloud);

% Preview one preprocessed training sample.
data = preview(dsTrain);
figure;pcshow(data{1,1},[0 0 1],"MarkerSize",40,"VerticalAxisDir","down");
xlabel("X");ylabel("Y");zlabel("Z");title(data{1,2})
% Input transform.
inputChannelSize = 3;
hiddenChannelSize1 = [64,128];
hiddenChannelSize2 = 256;
[parameters.InputTransform, state.InputTransform] = initializeTransform(inputChannelSize,hiddenChannelSize1,hiddenChannelSize2);

% First shared MLP.
inputChannelSize = 3;
hiddenChannelSize = [64 64];
[parameters.SharedMLP1,state.SharedMLP1] = initializeSharedMLP(inputChannelSize,hiddenChannelSize);

% Feature transform (uses hiddenChannelSize1 and hiddenChannelSize2 defined just below).
inputChannelSize = 64;
hiddenChannelSize1 = [64,128];
hiddenChannelSize2 = 256;
[parameters.FeatureTransform, state.FeatureTransform] = initializeTransform(inputChannelSize,hiddenChannelSize1,hiddenChannelSize2);

% Second shared MLP.
inputChannelSize = 64;
hiddenChannelSize = 64;
[parameters.SharedMLP2,state.SharedMLP2] = initializeSharedMLP(inputChannelSize,hiddenChannelSize);

% Classification MLP.
inputChannelSize = 64;
hiddenChannelSize = [512,256];
numClasses = numel(classes);
[parameters.ClassificationMLP, state.ClassificationMLP] = initializeClassificationMLP(inputChannelSize,hiddenChannelSize,numClasses);
% Training hyperparameters.
numEpochs = 60;
learnRate = 0.001;
l2Regularization = 0.1;
learnRateDropPeriod = 15;
learnRateDropFactor = 0.5;

% Adam optimizer settings and state.
gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;
avgGradients = [];
avgSquaredGradients = [];
[lossPlotter, trainAccPlotter,valAccPlotter] = initializeTrainingProgressPlot;
% Number of classes
numClasses = numel(classes);
% Initialize the iterations
iteration = 0;
% To calculate the time for training
start = tic;
% Loop over the epochs
for epoch = 1:numEpochs

    % Reset training and validation datastores.
    reset(dsTrain);
    reset(dsVal);

    % Iterate through data set.
    while hasdata(dsTrain) % if no data to read, exit the loop to start the next epoch
        iteration = iteration + 1;
        % Read data.
        data = read(dsTrain);
        % Create batch.
        [XTrain,YTrain] = batchData(data,classes);
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients function.
        [gradients, loss, state, acc] = dlfeval(@modelGradients,XTrain,YTrain,parameters,state);
        % L2 regularization.
        gradients = dlupdate(@(g,p) g + l2Regularization*p,gradients,parameters);
        % Update the network parameters using the Adam optimizer.
        [parameters, avgGradients, avgSquaredGradients] = adamupdate(parameters, gradients, ...
            avgGradients, avgSquaredGradients, iteration,learnRate,gradientDecayFactor, squaredGradientDecayFactor);
        % Update the training progress.
        D = duration(0,0,toc(start),"Format","hh:mm:ss");
        title(lossPlotter.Parent,"Epoch: " + epoch + ", Elapsed: " + string(D))
        addpoints(lossPlotter,iteration,double(gather(extractdata(loss))))
        addpoints(trainAccPlotter,iteration,acc);
        drawnow
    end

    % Create confusion matrix
    cmat = sparse(numClasses,numClasses);
    % Classify the validation data to monitor the training process
    while hasdata(dsVal)
        data = read(dsVal); % Get the next batch of data.
        [XVal,YVal] = batchData(data,classes); % Create batch.
        % Compute label predictions.
        isTrainingVal = 0; % Set to zero for validation data
        YPred = pointnetClassifier(XVal,parameters,state,isTrainingVal);

        % Choose prediction with highest score as the class label for
        % XVal.
        [~,YValLabel] = max(YVal,[],1);
        [~,YPredLabel] = max(YPred,[],1);
        cmat = aggreateConfusionMetric(cmat,YValLabel,YPredLabel); % Update the confusion matrix
    end
    % Update training progress plot with average classification accuracy.
    acc = sum(diag(cmat))./sum(cmat,"all");
    addpoints(valAccPlotter,iteration,acc);
    % Update the learning rate
    if mod(epoch,learnRateDropPeriod) == 0
        learnRate = learnRate * learnRateDropFactor;
    end
    reset(dsTrain); % Reset the training data since all the training data were already read
    % Shuffle the data at every epoch
    dsTrain.UnderlyingDatastore.Files = dsTrain.UnderlyingDatastore.Files(randperm(length(dsTrain.UnderlyingDatastore.Files)));
    reset(dsVal);
end
% Final evaluation on the validation set.
cmat = sparse(numClasses,numClasses); % Sparse numClasses-by-numClasses accumulator for the confusion matrix
reset(dsVal); % Reset the validation data
data = readall(dsVal); % Read all validation data
[XVal,YVal] = batchData(data,classes); % Create batch.
% Classify the validation data using the helper function pointnetClassifier
YPred = pointnetClassifier(XVal,parameters,state,isTrainingVal);
% Choose prediction with highest score as the class label for
% XVal.
[~,YValLabel] = max(YVal,[],1);
[~,YPredLabel] = max(YPred,[],1);

% Collect confusion metrics.
cmat = aggreateConfusionMetric(cmat,YValLabel,YPredLabel);
figure;chart = confusionchart(cmat,classes);

acc = sum(diag(cmat))./sum(cmat,"all")
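The accuracy expression at the end of the script divides the diagonal of the confusion matrix by the total count. A small worked example may make this concrete. The 3-class matrix below is made up, and it assumes rows index the true class and columns the predicted class; the helper aggreateConfusionMetric is not listed here, so that convention is an assumption.

```matlab
% Worked example: accuracy from a confusion matrix.
% Hypothetical 3-class matrix; rows = true class, columns = predicted class (assumed).
cmatExample = [8 1 1;
               2 7 1;
               0 2 8];

% Overall accuracy: correctly classified samples (the diagonal) over all samples.
overallAcc = sum(diag(cmatExample)) / sum(cmatExample,"all");   % 23/30

% Per-class recall: diagonal entry over the row sum of each true class.
perClassRecall = diag(cmatExample) ./ sum(cmatExample,2);       % [0.8; 0.7; 0.8]
```

The per-class values are worth checking alongside the overall figure, since a single accuracy number can hide a class that is classified poorly.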
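The training loop multiplies the learning rate by learnRateDropFactor at the end of every learnRateDropPeriod-th epoch. The sketch below tabulates the rate in effect during each epoch for the values used in the script (60 epochs, initial rate 0.001, period 15, factor 0.5). The variable names are introduced only for this illustration.

```matlab
% Learning rate in effect during each epoch for a piecewise-constant schedule.
% Values mirror the script; variable names are illustrative.
numEpochsExample = 60;
initialLearnRate = 0.001;
dropPeriod = 15;
dropFactor = 0.5;

epochs = (1:numEpochsExample)';
% The drop happens at the END of epochs 15, 30, 45, ..., so epoch e has
% already seen floor((e-1)/dropPeriod) drops when it starts.
numDrops = floor((epochs - 1) / dropPeriod);
learnRatePerEpoch = initialLearnRate * dropFactor.^numDrops;

% Rate used in the first epoch of each stage: epochs 1, 16, 31, 46.
stageRates = learnRatePerEpoch([1 16 31 46])';   % [1e-3 5e-4 2.5e-4 1.25e-4]
```

With these settings the rate drops three times before the last stage of training, and the drop at the end of epoch 60 has no effect because training has finished.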
4. Simulation results
5. References
[1] Qi C R, Su H, Mo K, et al. PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation[C]// 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2017.
Reposted from: https://blog.csdn.net/ccsss22/article/details/125434324