飞道的博客

Mindspore网络构建

524人阅读  评论(0)

网络构建

神经网络模型是由神经网络层和Tensor操作构成的,mindspore.nn提供了常见神经网络层的实现,在MindSpore中,Cell类是构建所有网络的基类,也是网络的基本单元。一个神经网络模型表示为一个Cell,它由不同的子Cell构成。使用这样的嵌套结构,可以简单地使用面向对象编程的思维,对神经网络结构进行构建和管理。

 

构建Mnist数据集分类的神经网络


  
  1. import mindspore
  2. from mindspore import nn, ops

 个人理解:在代码层面也就是直接调用模块,通过模块来实现我们想要达成的效果。

定义模型类

定义神经网络时,可以继承nn.Cell类,在__init__方法中进行子Cell的实例化和状态管理,在construct方法中实现Tensor操作。


  
  1. class Network(nn.Cell):
  2. def __init__( self):
  3. super().__init__()
  4. self.flatten = nn.Flatten()
  5. self.dense_relu_sequential = nn.SequentialCell(
  6. nn.Dense( 28* 28, 512),
  7. nn.ReLU(),
  8. nn.Dense( 512, 512),
  9. nn.ReLU(),
  10. nn.Dense( 512, 10)
  11. )
  12. def construct( self, x):
  13. x = self.flatten(x)
  14. logits = self.dense_relu_sequential(x)
  15. return logits
  16. #构建完成后,实例化Network对象,并查看其结构。
  17. model = Network()
  18. print(model)
  19. Network<
  20. (flatten): Flatten<>
  21. (dense_relu_sequential): SequentialCell<
  22. ( 0): Dense<input_channels= 784, output_channels= 512, has_bias= True>
  23. ( 1): ReLU<>
  24. ( 2): Dense<input_channels= 512, output_channels= 512, has_bias= True>
  25. ( 3): ReLU<>
  26. ( 4): Dense<input_channels= 512, output_channels= 10, has_bias= True>
  27. >
  28. >
  29. #我们构造一个输入数据,直接调用模型,可以获得一个10维的Tensor输出,其包含每个类别的原始预测值。
  30. X = ops.ones(( 1, 28, 28), mindspore.float32)
  31. logits = model(X)
  32. print(logits)
  33. pred_probab = nn.Softmax(axis= 1)(logits)
  34. y_pred = pred_probab.argmax( 1)
  35. print( f"Predicted class: {y_pred}")

模型层

分解上节构造的神经网络模型中的每一层。


  
  1. input_image = ops.ones(( 5, 15, 18), mindspore.float32)
  2. print(input_image.shape)
  3. #输出结果
  4. ( 5, 15, 18)
  5. #nn.Flatten层的实例化
  6. flatten = nn.Flatten()
  7. flat_image = flatten(input_image)
  8. print(flat_image.shape)
  9. #nn.Dense全链层,权重和偏差对输入进行线性变换
  10. layer1 = nn.Dense(in_channels= 20* 20, out_channels= 20)
  11. hidden1 = layer1(flat_image)
  12. print(hidden1.shape)
  13. #nn.ReLU层,网络中加入非线性的激活函数
  14. print( f"Before ReLU: {hidden1}\n\n")
  15. hidden1 = nn.ReLU()(hidden1)
  16. print( f"After ReLU: {hidden1}")
  17. #nn.SequentialCell容器配置
  18. seq_modules = nn.SequentialCell(
  19. flatten,
  20. layer1,
  21. nn.ReLU(),
  22. nn.Dense( 15, 10)
  23. )
  24. logits = seq_modules(input_image)
  25. print(logits.shape)
  26. #nn.Softmax全链层返回的值进行预测
  27. softmax = nn.Softmax(axis= 1)
  28. pred_probab = softmax(logits)

参数模型

网络内部神经网络层具有权重参数和偏置参数


  
  1. print( f"Model structure: {model}\n\n")
  2. for name, param in model.parameters_and_names():
  3. print( f"Layer: {name}\nSize: {param.shape}\nValues : {param[:2]} \n")

内置神经网络(mindspore.nn)

1.基本构成单元

接口名 概述
mindspore.nn.Cell MindSpore中神经网络的基本构成单元。
mindspore.nn.GraphCell 运行从MindIR加载的计算图。
mindspore.nn.LossBase 损失函数的基类。
mindspore.nn.Optimizer 用于参数更新的优化器基类。

2.循环神经网络层

接口名 概述
mindspore.nn.RNN 循环神经网络(RNN)层,其使用的激活函数为tanh或relu。
mindspore.nn.RNNCell 循环神经网络单元,激活函数是tanh或relu。
mindspore.nn.GRU GRU(Gate Recurrent Unit)称为门控循环单元网络,是循环神经网络(Recurrent Neural Network, RNN)的一种。
mindspore.nn.GRUCell GRU(Gate Recurrent Unit)称为门控循环单元。
mindspore.nn.LSTM 长短期记忆(LSTM)网络,根据输出序列和给定的初始状态计算输出序列和最终状态。
mindspore.nn.LSTMCell 长短期记忆网络单元(LSTMCell)。

3.嵌入层

接口名 概述
mindspore.nn.Embedding 嵌入层。
mindspore.nn.EmbeddingLookup 嵌入查找层。
mindspore.nn.MultiFieldEmbeddingLookup 根据指定的索引和字段ID,返回输入Tensor的切片。

4.池化层

接口名 概述
mindspore.nn.AdaptiveAvgPool1d 对输入的多维数据进行一维平面上的自适应平均池化运算。
mindspore.nn.AdaptiveAvgPool2d 二维自适应平均池化。
mindspore.nn.AdaptiveAvgPool3d 三维自适应平均池化。
mindspore.nn.AdaptiveMaxPool1d 对输入的多维数据进行一维平面上的自适应最大池化运算。
mindspore.nn.AdaptiveMaxPool2d 二维自适应最大池化运算。
mindspore.nn.AvgPool1d 对输入的多维数据进行一维平面上的平均池化运算。
mindspore.nn.AvgPool2d 对输入的多维数据进行二维的平均池化运算。
mindspore.nn.MaxPool1d 对时间数据进行最大池化运算。
mindspore.nn.MaxPool2d 对输入的多维数据进行二维的最大池化运算。

5. 图像处理层

接口名 概述
mindspore.nn.CentralCrop 根据指定比例裁剪出图像的中心区域。
mindspore.nn.ImageGradients 计算每个颜色通道的图像渐变,返回为两个Tensor,分别表示高和宽方向上的变化率。
mindspore.nn.MSSSIM 多尺度计算两个图像之间的结构相似性(SSIM)。
mindspore.nn.PSNR 在批处理中计算两个图像的峰值信噪比(PSNR)。
mindspore.nn.ResizeBilinear 使用双线性插值调整输入Tensor为指定的大小。
mindspore.nn.SSIM 计算两个图像之间的结构相似性(SSIM)。

因为篇幅原因,这里就不全部介绍了,后面会继续更新


转载:https://blog.csdn.net/weixin_50481708/article/details/127875332
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场