小言_互联网的博客

lstm 前后两次调用权重一样

271人阅读  评论(0)
import tensorflow as tf
import numpy as np
# Demo (TF 1.x graph mode): the same LSTM cell object is passed to two
# separate tf.nn.dynamic_rnn calls, and the results are compared to show
# that both calls use the same weights.

# Single cell instance shared by both dynamic_rnn calls below.
cell = tf.contrib.rnn.BasicLSTMCell(num_units=4, state_is_tuple=True)

# Input batch: 2 sequences, 10 time steps, 8 features per step.
X = tf.placeholder(tf.float32, (2, 10, 8))
# Valid length of each sequence. dynamic_rnn documents sequence_length as an
# int32/int64 vector of size [batch_size], so declare it that way instead of
# relying on an implicit cast from float32.
X_lengths = tf.placeholder(tf.int32, (2,))

# First call: runs the cell over the whole batch.
outputs, last_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float32,
    sequence_length=X_lengths,
    inputs=X)

# Toy objective: push every output towards 1. The sum makes the loss an
# explicit scalar; differentiating a non-scalar tensor sums over its elements
# anyway, so the resulting gradients are the same.
label = tf.ones((2, 10, 4))
loss = tf.reduce_sum(label - outputs)
learning_rate = 0.01
train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

# Second call: same cell object, fed only the first example of the batch.
outputs2, last_states2 = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float32,
    sequence_length=X_lengths[:1],
    inputs=X[:1])

with tf.Session() as sess:
    X1 = np.ones((2, 10, 8))
    # The second example has a valid length of 6; its outputs past step 6
    # come back as zeros.
    X_lengths1 = [10, 6]
    feed = {X: X1, X_lengths: X_lengths1}

    sess.run(tf.global_variables_initializer())
    # A few training steps so the weights differ from their initial values.
    for _ in range(10):
        sess.run(train_op, feed_dict=feed)

    # Fetch into new names so the graph tensors are not overwritten by
    # their evaluated NumPy values.
    outputs_val, last_states_val, outputs2_val, last_states2_val = sess.run(
        (outputs, last_states, outputs2, last_states2), feed_dict=feed)

    print(last_states_val)
    print(outputs_val)
    print(last_states2_val)
    print(outputs2_val)
    # Difference between the second call and the first example of the first
    # call; values near zero indicate both calls used the same weights.
    print(outputs2_val - outputs_val[0])

输出为

LSTMStateTuple(c=array([[3.6967206 , 0.36333948, 2.6907382 , 0.7764288 ],
       [2.9159408 , 0.34533507, 2.3552992 , 0.7168741 ]], dtype=float32), h=array([[0.83917683, 0.16514859, 0.82227004, 0.2351862 ],
       [0.83666867, 0.15811422, 0.8160259 , 0.22364864]], dtype=float32))
[[[0.5070681  0.08897948 0.5242556  0.1298186 ]
  [0.72878003 0.11847612 0.7119349  0.16895662]
  [0.8013882  0.13465777 0.77578974 0.19117852]
  [0.8250746  0.14580028 0.80016226 0.20640413]
  [0.8334302  0.15326948 0.8107919  0.21670982]
  [0.83666867 0.15811422 0.8160259  0.22364864]
  [0.83804226 0.16120268 0.81887186 0.22833543]
  [0.8386776  0.16315424 0.8205446  0.2315192 ]
  [0.8389986  0.16438086 0.8215884  0.23369396]
  [0.83917683 0.16514859 0.82227004 0.2351862 ]]

 [[0.5070681  0.08897948 0.5242556  0.1298186 ]
  [0.72878003 0.11847612 0.7119349  0.16895662]
  [0.8013882  0.13465777 0.77578974 0.19117852]
  [0.8250746  0.14580028 0.80016226 0.20640413]
  [0.8334302  0.15326948 0.8107919  0.21670982]
  [0.83666867 0.15811422 0.8160259  0.22364864]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]
  [0.         0.         0.         0.        ]]]
LSTMStateTuple(c=array([[3.6967204 , 0.36333948, 2.6907384 , 0.7764287 ]], dtype=float32), h=array([[0.8391768 , 0.16514859, 0.8222699 , 0.23518619]], dtype=float32))
[[[0.5070681  0.08897948 0.5242555  0.12981859]
  [0.72878003 0.11847612 0.7119349  0.16895662]
  [0.8013882  0.13465776 0.77578974 0.19117847]
  [0.8250748  0.14580028 0.8001623  0.20640409]
  [0.8334301  0.15326953 0.8107919  0.21670981]
  [0.8366687  0.15811422 0.816026   0.22364862]
  [0.83804226 0.16120267 0.818872   0.22833538]
  [0.8386776  0.16315423 0.8205446  0.2315192 ]
  [0.8389986  0.16438086 0.82158846 0.23369391]
  [0.8391768  0.16514859 0.8222699  0.23518619]]]
[[[ 0.0000000e+00  0.0000000e+00 -5.9604645e-08 -1.4901161e-08]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00 -1.4901161e-08  0.0000000e+00 -4.4703484e-08]
  [ 1.7881393e-07  0.0000000e+00  5.9604645e-08 -4.4703484e-08]
  [-5.9604645e-08  4.4703484e-08  0.0000000e+00 -1.4901161e-08]
  [ 5.9604645e-08  0.0000000e+00  5.9604645e-08 -1.4901161e-08]
  [ 0.0000000e+00 -1.4901161e-08  1.1920929e-07 -4.4703484e-08]
  [ 0.0000000e+00 -1.4901161e-08  0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  5.9604645e-08 -4.4703484e-08]
  [-5.9604645e-08  0.0000000e+00 -1.1920929e-07 -1.4901161e-08]]]

用同一个cell调用两次dynamic_rnn(相当于创建两个lstm)
第一个输入为[2,10,8]
第二个的输入为第一个输入的第一个样本(X[:1],形状为[1,10,8])
前后两次调用结果一样(差值约为1e-7,属于float32的舍入误差)
rnn权重共享


转载:https://blog.csdn.net/weixin_44125720/article/details/102391778
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场