本文内容主要来自阿里云的人工智能学习路线系列课程中的Tensorflow课程。将课程中的代码移植到了TensorFlow2.1中,并对课程中的知识做了补充。
文章目录
两种在tensorflow 2.x版本运行1.x代码的方法
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tensorflow基本结构
def tensorflow_demo():
'''
Tensorflow 基本结构
'''
# 原生Python的加法运算
a,b = 2,3
print("普通加法运算的结果:\n",a+b)
# tensorflow实现加法运算
# 构建图
a_t,b_t = tf.constant(2),tf.constant(3)
print("Tensorflow加法运算结果:\n",a_t+b_t)
# 开启会话
with tf.compat.v1.Session() as sess:
print("c_t_value:\n",sess.run(a_t+b_t))
return None
tensorflow_demo()
输出:
普通加法运算的结果:
5
Tensorflow加法运算结果:
Tensor("add:0", shape=(), dtype=int32)
c_t_value:
5
2.TensorFlow 结构
tensorflow 程序常被组织成一个构图阶段和一个执行图阶段。
在构建阶段,数据与操作的执行被描述成一个图。
流程图:定义数据结构(张量tensor)和操作(节点operation)
在执行阶段,使用会话执行构建好的图中的操作。调用各方资源,将定义好的数据和操作运行起来。
- 图和会话:
- 图:tensorflow将计算表示为指令之间的依赖关系;
- 会话:tensorflow跨一个或多个本地或远程设备运行数据流图的机制。
- 张量:tensorflow中的基本数据对象
- 节点:提供图中执行的操作。在数据流图中,节点通常以圆、椭圆或方框表示,代表对数据的运算或某种操作。
2.1 数据流图
构建数据流图时,需要两个基础元素:点(node)和边(edge)。
- 节点:在数据流图中,节点通常以圆、椭圆或方框表示,代表对数据的运算或某种操作。例如,在图11-26中,就有5个节点,分别表示输入(input)、乘法(mul)和加法(add)。
- 边:数据流图是一种有向图,“边”通常用带箭头线段表示,实际上,它是节点之间的连接。指向节点的边表示输入,从节点引出的边表示输出。输入可以是来自其他数据流图,也可以表示文件读取、用户输入。输出就是某个节点的操作(Operation,下文简称Op)结果。
# 设置日志等级以屏蔽警告信息
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
log信息共有四个等级,按重要性递增为:
INFO(通知) < WARNING(警告) < ERROR(错误) < FATAL(致命错误);
os.environ['TF_CPP_MIN_LOG_LEVEL']
值的含义:不同值设置的是基础log信息(base_loging),运行时会输出base等级及其之上(更为严重)的信息。具体如下:
base_loging | 屏蔽信息 | 输出信息 | |
---|---|---|---|
“0” | INFO | 无 | INFO + WARNING + ERROR + FATAL |
“1” | WARNING | INFO | WARNING + ERROR + FATAL |
“2” | ERROR | INFO + WARNING | ERROR + FATAL |
“3” | FATAL | INFO + WARNING + ERROR | FATAL |
2.2 图与tensorboard
2.2.1 什么是图结构?
图结构:数据(tensor)+操作(operation)
2.2.2 图的相关操作
1.默认图。查看默认图的方法:
1)调用方法
tf.compat.v1.get_default_graph()
2)查看属性 .graph
2.创建图
使用tf.Graph()
方法创建自定义图
new_g = tf.Graph()
with new_g.as_default()
:定义数据和操作
with tf.Session(graph=new_g) as new_sess
:
def graph_demo():
'''
图的演示
:return:
'''
# tensorflow实现加法运算
# 构建图
a_t,b_t = tf.constant(2),tf.constant(3)
c_t = a_t+b_t
# a_t,b_t打印输出:
'''
a_t:
Tensor("Const_26:0", shape=(), dtype=int32)
b_t:
Tensor("Const_27:0", shape=(), dtype=int32)
'''
# 其中Const_26为指令名称,与tensorboard中graph显示的一致;
# 操作函数constant在运行过程中会生成一个操作对象operation,z冒号后边的“0”表示constant操作对象operation的输出个数为1个。
# 如果是1,则表示2个,输出索引从0开始。
print("a_t:\n",a_t)
print("b_t:\n",b_t)
print("c_t:\n",c_t )
# 查看默认图
# 方法一:调用方法
default_g = tf.compat.v1.get_default_graph()
print('a_t的图属性:\n',default_g)
# 方法二:查看属性
print('a_t的图属性:\n',a_t.graph)
# 开启默认会话,tf.compat.v1.Session()运行默认图中的操作
with tf.compat.v1.Session() as sess:
print("c_t_value:\n",sess.run(c_t))
# 也可用eval()方法,但必须在session中。多用在交互模式下。
# print("c_t_value:\n",c_t.eval())
print('sess的图属性:\n',sess.graph)
# tensorboard
tf.compat.v1.summary.FileWriter('D:/AliyunEDU/04 summary/',graph=sess.graph)
#-------------------------------------------------------
# 自定义图
new_g = tf.Graph()
with new_g.as_default():
a_new,b_new = tf.constant(20),tf.constant(30)
c_new = a_new+b_new
print("c_new:\n",c_new)
print('c_new的图属性:\n',c_new.graph)
# 开启new_g的会话
with tf.compat.v1.Session(graph=new_g) as new_sess:
c_new_value = new_sess.run(c_new)
print('c_new_value:\n',c_new_value)
print('new_sess的图属性:\n',new_sess.graph)
# 两种图的地址不同
# 每张图都有自己的命名空间
graph_demo()
输出:
a_t:
Tensor("Const_28:0", shape=(), dtype=int32)
b_t:
Tensor("Const_29:0", shape=(), dtype=int32)
c_t:
Tensor("add_15:0", shape=(), dtype=int32)
a_t的图属性:
<tensorflow.python.framework.ops.Graph object at 0x000002D94549C188>
a_t的图属性:
<tensorflow.python.framework.ops.Graph object at 0x000002D94549C188>
c_t_value:
5
sess的图属性:
<tensorflow.python.framework.ops.Graph object at 0x000002D94549C188>
c_new:
Tensor("add:0", shape=(), dtype=int32)
c_new的图属性:
<tensorflow.python.framework.ops.Graph object at 0x000002D9470E87C8>
c_new_value:
50
new_sess的图属性:
<tensorflow.python.framework.ops.Graph object at 0x000002D9470E87C8>
2.2.3 tensorboard
-
1.数据序列化
tensorboard通过读取tensorflow的事件文件来运行,需要将数据生成一个序列化的summary protobuf对象。tf.summary.FileWriter(path,grap=sess.graph)
-
2.启动tensorboard
tensorboard --logdir=path
windows示例: 注意是反斜杠和双引号!
tensorboard --logdir="D:\AliyunEDU\04 summary"
2.2.4 Op
数据:tensor对象
操作:operation对象 -Op
1.常见Op如下表
类型 | 实例 |
---|---|
标量运算 | add, sub, mul, div, exp, log, greater, less, equal |
向量运算 | concat, slice, splot. constant, rank, shape, shuffle |
矩阵运算 | matmul, matrixinverse, matrixdateminant |
带状态的运算 | Variable, assgin, assginadd |
神经网络组件 | softmax, sigmoid relu,convolution,max_pool |
存储,恢复 | Save, Restroe |
队列及同步运算 | Enqueue, Dequeue, MutexAcquire, MutexReiease |
控制流 | Merge, Switch, Enter, Leave, NextIteration |
2.操作函数与操作对象
操作函数 | 操作对象 |
---|---|
tf.constant(tensor对象) | 输入tensor对象 -Const -输出tensor对象 |
tensor.add(tensor对象1,tensor对象2) | 输入tensor对象1,tensor对象2 -Add对象 -输出tensor对象3 |
一个操作对象(Operation)是tensorflow图中的一个节点,可以接收0个或者多个输入tensor,并且可以输出0个或者多个tensor,operation对象是通过Op构造函数(如tf.matual())创建的。
例如:c = tf.constant(3.0)
创建了一个Operation对象,类型为matmul类型,它将张量a,b作为输入,c作为输出,并且输出数据,打印的时候也是打印数据。其中,tf.matual()是函数,在执行matmul函数的过程中会通过matmul类创建一个与之对应的对象。
注意,tf.Tensor 对象以输出该张量的tf.Operation明确命名。张量名称的形式为“<OP_NAME>:<i>
”,其中:
- “
<OP_NAME>
”是生成该张量的指令的名称 - “
<i>
”是一个整数,它表示该张量在指令分输出中的索引
3.指令名称
tf.Graph
对象为其包含的tf.operation对象定义的一个命名空间。tensorflow会自动为图中的每个指令选择一个唯一的名称,用户也可以指定描述性名称,使程序阅读起来更轻松。我们可用操作方法中的name
参数改写指令名称。例如:a_t = tf.constant(2,name=‘a_t’)。
2.3 Session
tf.Session
:用于完整的程序当中。tensorflow使用tf.Session类来表示客户端(通常为Python程序)与C++运行时之间的连接。
tf.interactiveSession
:用于交互式上下文中的tensor。
__init__(target='',graph=None,config=None)
会话可能拥有的资源,如tf.Variable
,tf.QueueBase
和tf.readerBase
。当这些资源不再需要时,释放这些资源非常重要。因此,需要调用tf.Session.close
会话中的方法,或将会话用作上下文管理器。target
:如果将此参数留空(默认设置),会话将使用本地计算机中的设备。可以指定grpc://网址,以便确定tensorflow服务器的地址,这使得会话可以访问该服务器控制的计算机上的所有设备。graph=None
: 运行默认图config
:此参数允许在制定一个tf.ConfigProto以便控制会话的行为。例如,ConfigProto协议用于打印设备使用信息。
def session_demo():
'''
会话的演示
:return:
'''
a_t,b_t = tf.constant(2),tf.constant(3)
c_t = a_t + b_t
# 运行会话并打印设备信息
with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=True,
log_device_placement=True)) as sess:
print("c_t_value:\n",sess.run(c_t))
session_demo()
- 会话中的
run()
run(fetches,feed_dict=None,options=None.run_metadata=None)
- 通过run()来运行operation
fetches
:单一的operation,或者列表、元组(其他不属于tensorflow的类型不行)。feed_dict
:参数允许调用者覆盖图中张量的值,运行时赋值。与tf.placeholder
搭配使用,则会检查值得形状是否与占位符兼容。
接上例,同时查看多个值:
a,b,c = sess.run([a_t,b_t,c_t]) # 列表
print(a,b,c)
feed()
- placeholder提供占位符,run时候通过feed_dict指定参数。未知数据样本维度时,可以先占位;当明确样本维度时,再用feed_dict加载。
def session_feed():
a_ph = tf.compat.v1.placeholder(tf.float32)
b_ph = tf.compat.v1.placeholder(tf.float32)
c_ph = tf.add(a_ph,b_ph)
print('c_ph:',c_ph)
with tf.compat.v1.Session() as sess:
# 运行placeholder
c_ph_value = sess.run(c_ph, feed_dict={a_ph:2, b_ph:3})
print("c_t_value:\n",c_ph_value)
session_feed()
输出:
c_ph: Tensor("Add_23:0", dtype=float32)
c_t_value:
5.0
2.4 张量 Tensor
tensorflow的张量就是一个你维数组,类型为tf.tensor。tensor具有以下两个重要的属性:
- type:数据类型
- shape:形状(阶)
2.4.1 张量的阶
阶 | 数学实例 | Python | 例子 |
---|---|---|---|
0 | 纯量 | 只有大小 | a=64 |
1 | 向量 | 大小和方向 | v=[1,2,3] |
2 | 矩阵 | 数据表 | m=[[1,2,3],[4,5,6]] |
3 | 3阶张量 | 数据立体 | t=[[[1],[2]],[[3],[4]]] |
n | n阶张量 | … | … |
def tensor_demo():
'''
张量的演示
'''
tensor1 = tf.constant(4.0)
tensor2 = tf.constant([1,2,3])
tensor3 = tf.constant([[4],[5],[6]], dtype=tf.int32)
print("tensor1:\n",tensor1)
print("tensor2:\n",tensor2)
print("tensor3:\n",tensor3)
return None
tensor_demo()
输出:
tensor1:
Tensor("Const_36:0", shape=(), dtype=float32)
tensor2:
Tensor("Const_37:0", shape=(3,), dtype=int32)
tensor3:
Tensor("Const_38:0", shape=(3, 1), dtype=int32)
2.4.2 创建张量的指令
- 固定值张量
tf.zeros(shape, dtype=tf.float32, name=None)
tf.ones(shape, dtype=tf.float32, name=None)
tf.constant(value, dtype=None,shape=None,name='Const')
- 随机值张量
tf.random_normal(shape, mean=0.0, stddev=1.0,dtype=tf.float32,name=None)
2.4.3 张量的变换
回顾:ndarray属性的修改
- 类型的修改
ndarray.astype(type)
tf.cast(tensor,dtype)
:不会改变原始的tensor,返回新的改变类型后的tensor。ndarray.tostring()
- 形状的修改
-
ndarray.reshape(shape)
- 自动计算形状
- 返回一个新的数组
-
ndarray.resize(shape) – 在原数组上修改
1)如何改变静态形状?- 什么情况下采可以改变/更新静态形状?
只有在形状没有完全固定下来的情况下 tensor.set_shape(shape)
:在原来的形状上修改
2)如何改变动态形状?
tf.reshape(tensor,type)
:返回一个新的数组,可以改变形状(行、列、维度),但不能改变元素数量!
- 什么情况下采可以改变/更新静态形状?
tensorflow 类型修改
tf.to_string_to_number(string_tensor, out_type=None,name=None)
tf.to_double(x, name='ToDouble')
tf.to_float(x,name='ToFloat')
tf.to_bfloat16(x, name='ToBFloat16')
tf.to_int32(x,name='ToInt32')
tf.to_int64(x,name='ToInt64')
tf.cast(x,dtype,name=None)
def tensor_demo():
'''
张量类型的修改
'''
tensor = tf.constant([[4],[5],[6]], dtype=tf.int32)
tensor_cast = tf.cast(tensor, dtype=tf.float32)
print("tensor3 before:",tensor)
print("tensor3 after:",tensor_cast)
# 更新/改变静态形状
# 没有完全固定下来的静态形状
# shape中为None的维度,可以在以后的更新中改变,其余有固定值的维度无法改变
a_p = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None,None])
b_p = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None,10])
c_p = tf.compat.v1.placeholder(dtype=tf.float32, shape=[3,2])
print('a_p{}\n b_p{}\n c_p{}\n'.format(a_p,b_p,c_p))
# 更新形状未确定的部分
a_p.set_shape([2,3])
print('a_p_update:',a_p)
# 动态形状修改
a_p = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None,None])
a_p_reshape = tf.reshape(a_p,shape=[2,3,1])
print('a_p_reshape:',a_p_reshape)
return None
tensor_demo()
输出:
tensor3 before: Tensor("Const_48:0", shape=(3, 1), dtype=int32)
tensor3 after: Tensor("Cast_9:0", shape=(3, 1), dtype=float32)
a_pTensor("Placeholder_24:0", shape=(None, None), dtype=float32)
b_pTensor("Placeholder_25:0", shape=(None, 10), dtype=float32)
c_pTensor("Placeholder_26:0", shape=(3, 2), dtype=float32)
a_p_update: Tensor("Placeholder_24:0", shape=(2, 3), dtype=float32)
a_p_reshape: Tensor("Reshape_2:0", shape=(2, 3, 1), dtype=float32)
tensorflow 形状修改
tensorflow的张量具有两种形状变换,动态形状和静态形状
tf.reshape
tf.set_shape
2.4.4 张量的数学运算
- 算数运算符
- 基本数学函数
- 矩阵运算
- reduce操作
- 序列索引操作
2.5 变量OP
tensorflow变量是表示程序处理的共享持久状态的最佳方法。变量通过 tf.Variable OP类进行操作。变量的特点:
- 存储持久化
- 可修改值
- 可指定被训练
2.5.1 创建变量
tf.Variable(initial_value=None,trainable=True,collections=None,name=None)
initial_value
:初始化的值trainable
:是否可训练collections
:新变量将添加到列出的图的集合中collections
def variable_demo():
'''
变量的演示
'''
# 定义变量
a = tf.Variable(initial_value=50)
b = tf.Variable(initial_value=40)
c = tf.add(a,b)
# 初始化变量
init = tf.compat.v1.global_variables_initializer()
# 开启会话
with tf.compat.v1.Session() as sess:
# 初始化
sess.run(init)
a_value,b_value,c_value = sess.run([a,b,c])
print("a_value:\n", a_value)
return None
variable_demo()
2.5.2 使用tf.variable_scope()修改命名空间
def variable_scope_demo():
'''
变量的演示
'''
# 使用命名空间定义变量
with tf.compat.v1.variable_scope("my_scope"):
a = tf.Variable(initial_value=50)
b = tf.Variable(initial_value=40)
c = tf.add(a,b)
# 初始化变量
init = tf.compat.v1.global_variables_initializer()
# 开启会话
with tf.compat.v1.Session() as sess:
# 初始化
sess.run(init)
a_value,b_value,c_value = sess.run([a,b,c])
print("a_value:\n", a_value)
return None
variable_scope_demo()
2.6 tensorflow API
2.6.1 基础API
tf.app
这个模块相当于为tensorflow进行的脚本提供一个main函数入口,可以定义脚本运行的flags。tf.image
图像处理操作。主要是一些颜色变换、变形和图像的编码和解码。tf.gfile
提供了一组文件操作函数。summary
用来生成tensorboard可用的统计日志,主要提供了四种类型:audio、image、histogram、scalartf.python_io
用来读写TFRecords文件。tf.train
提供了一些训练器,与tf.nn组合起来,实现一些网络的优化计算。tf.nn
提供了一些构建神经网络的底层函数。比如卷积、池化等操作。
2.6.2 高级API
tf.keras
tf.layers
提供更高级的概念层来定义一个模型。类似tf.kerastf.contrib(tf2.x已经弃用)
tf.contrib.layers提供能够将计算图中网络层、正则化、摘要操作、是构建计算图的高级操作,但包含不稳定和实验代码,有可能会改变。Tensorflow2.x已经弃用!tf.estimator
一个Estimator相当于model+training——evaluate的合体。在模块中,已经实现了几种简单的分类器和回归器,包括:Baseline,learning和dnn。
2.7 实例:实现线性回归训练
2.7.1线性回归
- 1)构建模型:
- 2)构造损失函数:
均方误差 - 3)优化损失:
梯度下降
2.7.2 案例
- 准备真实数据
- x 特征值
- y_true 目标值
- y_true = 0.8x+0.7
- 假定x 和y之间的关系满足:
- (网络训练,学习的两个参数)
流程分析:
(100, 1) * (1,1) = (100,1) y_predict = x * weights(1,1) + bias(1,1)
-
1)构建模型
y_predict = tf.matmul(x, weights) + bias
-
2)构建损失函数
error = tf.reduce_mean(tf.square(y_predict - y_true))
-
3)优化损失
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(error)
训练过程其实是在更新迭代optimizer。 -
4)学习率的设置、步数的设置与梯度爆炸
学习率越大,训练到较好结果的步数越小;学习率越小,训练到较好结果的步数越大。学习率过大会出现梯度爆炸的现象。
如何解决梯度爆炸?
- 重新设计网络
- 调整学习率
- 使用梯度截断(在训练过程中检查和限制梯度的大小)
- 使用激活函数
实例代码:
def linear_regression():
# 1.准备数据
X = tf.compat.v1.random_normal(shape=[100,1])
y_true = tf.matmul(X, [[0.8]]) + 0.7
# 2.构造模型
# 定义模型参数用 变量Variable
weights = tf.Variable(initial_value=tf.compat.v1.random_normal(shape=[1,1]))
bias = tf.Variable(initial_value=tf.compat.v1.random_normal(shape=[1,1]))
y_predict = tf.matmul(X,weights) + bias
# 3.构造损失函数
error = tf.reduce_mean(tf.square(y_predict - y_true))
# 4.优化损失
optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.05).minimize(error)
# 2_收集变量
tf.compat.v1.summary.scalar("error",error)
tf.compat.v1.summary.histogram("weights",weights)
tf.compat.v1.summary.histogram("bias",bias)
# 3_合并变量
merged = tf.compat.v1.summary.merge_all()
# 显式的初始化变量
init = tf.compat.v1.global_variables_initializer()
# 开启会话
with tf.compat.v1.Session() as sess:
# 初始化变量
sess.run(init)
# 1_创建事件文件
file_writer = tf.compat.v1.summary.FileWriter("D:/AliyunEDU/04 summary2/",graph=sess.graph)
# 查看初始化模型参数后的值
print("训练前的模型参数为:权重:%f,偏置:%f,损失为:%f" % (weights.eval(), bias.eval(), error.eval()))
# 开始训练
for i in range(200):
sess.run(optimizer)
print("第 %d 次训练后的模型参数为:权重:%f,偏置:%f,损失为:%f" % (i+1,weights.eval(), bias.eval(), error.eval()))
# 4_运行合并变量操作
summary = sess.run(merged)
# 5_将每次迭代后的变量写入事件文件
file_writer.add_summary(summary,i)
return None
linear_regression()
训练200个epoch的输出:
......
第 191 次训练后的模型参数为:权重:0.800000,偏置:0.700000,损失为:0.000000
第 192 次训练后的模型参数为:权重:0.800000,偏置:0.700000,损失为:0.000000
第 193 次训练后的模型参数为:权重:0.800000,偏置:0.700000,损失为:0.000000
第 194 次训练后的模型参数为:权重:0.800000,偏置:0.700000,损失为:0.000000
第 195 次训练后的模型参数为:权重:0.800000,偏置:0.700000,损失为:0.000000
第 196 次训练后的模型参数为:权重:0.800000,偏置:0.700000,损失为:0.000000
第 197 次训练后的模型参数为:权重:0.800000,偏置:0.700000,损失为:0.000000
第 198 次训练后的模型参数为:权重:0.800000,偏置:0.700000,损失为:0.000000
第 199 次训练后的模型参数为:权重:0.800000,偏置:0.700000,损失为:0.000000
第 200 次训练后的模型参数为:权重:0.800000,偏置:0.700000,损失为:0.000000
2.7.3增加其他功能
在tensorboard中观察模型的参数、损失值等变量值的变化。
- 收集变量
-
tf.summary.scalar(name='',tensor)
收集对于损失函数和准确率等单值变量,name为变量的名字,tensor为值。 -
tf.summary.histogram(name=''tensor)
收集高维度的变量参数 -
tf.summary.image(name='',tensor)
- 合并变量并写入事件文件
merged = tf.summary.merge_all()
- 运行合并:
summary = sess.run(merged)
每次迭代都运行 - 添加:
FileWriter.add_summary(summary,i)
i表示第几次的值
【总结】:
- 1.创建事件文件
- 2.收集变量
- 3.合并变量
- 4.每次迭代运行一次合并变量
- 5.每次迭代将summary对象写入事件文件
2.7.4 添加命名空间
查看tensorboard的时候,更清晰
def linear_regression_scope():
# 1.准备数据
with tf.compat.v1.variable_scope("prepare_data"):
X = tf.compat.v1.random_normal(shape=[100,1],name="X")
y_true = tf.matmul(X, [[0.8]]) + 0.7
# 2.构造模型
# 定义模型参数用 变量Variable
with tf.compat.v1.variable_scope("create_model"):
weights = tf.Variable(initial_value=tf.compat.v1.random_normal(shape=[1,1]),name="Weights")
bias = tf.Variable(initial_value=tf.compat.v1.random_normal(shape=[1,1]),name="Bias")
y_predict = tf.matmul(X,weights) + bias
# 3.构造损失函数
with tf.compat.v1.variable_scope("loss_function"):
error = tf.reduce_mean(tf.square(y_predict - y_true))
# 4.优化损失
with tf.compat.v1.variable_scope("optimizer"):
optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.05).minimize(error)
# 2_收集变量
tf.compat.v1.summary.scalar("error",error)
tf.compat.v1.summary.histogram("weights",weights)
tf.compat.v1.summary.histogram("bias",bias)
# 3_合并变量
merged = tf.compat.v1.summary.merge_all()
# 显式的初始化变量
init = tf.compat.v1.global_variables_initializer()
# 开启会话
with tf.compat.v1.Session() as sess:
# 初始化变量
sess.run(init)
# 1_创建事件文件
file_writer = tf.compat.v1.summary.FileWriter("D:/AliyunEDU/04 summary2/",graph=sess.graph)
# 查看初始化模型参数后的值
print("训练前的模型参数为:权重:%f,偏置:%f,损失为:%f" % (weights.eval(), bias.eval(), error.eval()))
# 开始训练
for i in range(200):
sess.run(optimizer)
print("第 %d 次训练后的模型参数为:权重:%f,偏置:%f,损失为:%f" % (i+1,weights.eval(), bias.eval(), error.eval()))
# 4_运行合并变量操作
summary = sess.run(merged)
# 5_将每次迭代后的变量写入事件文件
file_writer.add_summary(summary,i)
return None
linear_regression_scope()
tensorboard显示:
2.7.5 模型的保存与加载
tf.train.Saver(var_list=None,max_to_keep=5)
- 保存和加载模型(保存文件格式:checkpoint)
var_list
:指定将要保存和还原的变量。他可以作为一个dict或一个列表传递。max_to_keep
:指示要保留的最近检查点文件的最大数量。创建新文件时,会删除较旧的文件。如果无或0,则保留所有检查点文件。默认为5。
使用步骤:
saver = tf.train.Saver(var_list=None,max_to_keep=5)
- 保存:
saver.save(sess, path)
#需要先创建path中的文件夹 - 加载:
saver.restore(sess,path)
加载模型:
def linear_regression_scope_save():
# 1.准备数据
with tf.compat.v1.variable_scope("prepare_data"):
X = tf.compat.v1.random_normal(shape=[100,1],name="X")
y_true = tf.matmul(X, [[0.8]]) + 0.7
# 2.构造模型
# 定义模型参数用 变量Variable
with tf.compat.v1.variable_scope("create_model"):
weights = tf.Variable(initial_value=tf.compat.v1.random_normal(shape=[1,1]),name="Weights")
bias = tf.Variable(initial_value=tf.compat.v1.random_normal(shape=[1,1]),name="Bias")
y_predict = tf.matmul(X,weights) + bias
# 3.构造损失函数
with tf.compat.v1.variable_scope("loss_function"):
error = tf.reduce_mean(tf.square(y_predict - y_true))
# 4.优化损失
with tf.compat.v1.variable_scope("optimizer"):
optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.05).minimize(error)
# 2_收集变量
tf.compat.v1.summary.scalar("error",error)
tf.compat.v1.summary.histogram("weights",weights)
tf.compat.v1.summary.histogram("bias",bias)
# 3_合并变量
merged = tf.compat.v1.summary.merge_all()
# 显式的初始化变量
init = tf.compat.v1.global_variables_initializer()
# 创建saver对象
saver = tf.compat.v1.train.Saver(max_to_keep=5)
# 开启会话
with tf.compat.v1.Session() as sess:
# 初始化变量
sess.run(init)
# 1_创建事件文件
file_writer = tf.compat.v1.summary.FileWriter("D:/AliyunEDU/04 summary2/",graph=sess.graph)
# 查看初始化模型参数后的值
print("训练前的模型参数为:权重:%f,偏置:%f,损失为:%f" % (weights.eval(), bias.eval(), error.eval()))
# 开始训练
# for i in range(200):
# sess.run(optimizer)
# print("第 %d 次训练后的模型参数为:权重:%f,偏置:%f,损失为:%f" % (i+1,weights.eval(), bias.eval(), error.eval()))
# # 4_运行合并变量操作
# summary = sess.run(merged)
# # 5_将每次迭代后的变量写入事件文件
# file_writer.add_summary(summary,i)
# # 保存模型
# if i % 10 == 0:
# saver.save(sess,"D:/AliyunEDU/04 model/lr_model.ckpt")
if os.path.exists("D:/AliyunEDU/04 model/lr_model.ckpt"):
saver.restore(sess,"D:/AliyunEDU/04 model/checkpoint")
print("训练后的模型训练参数为:权重%f,偏置%f,损失%f" % (weights.eval(), bias.eval(), error.eval()))
return None
linear_regression_scope_save()
2.8 命令行参数的使用
面向对象编程。
tf.app.flags
,他支持应用从命令行接受参数,狂热以用来指定集群配置等。在tf.app.flags
下面有各种定义参数的类型。tf.app.flags
有一个FLAGS标志,他在程序中可以调用我们前面具体定义的flag_name
。- 通过
tf.app.run()
启动main(argv)
函数。
转载:https://blog.csdn.net/weixin_39653948/article/details/104917750