The tf.nn.dynamic_rnn() or tf.nn.rnn() operations allow you to specify the initial state of the RNN with the initial_state parameter. If you don't pass this parameter, the hidden states are initialized to zero vectors at the beginning of each training batch.
In TensorFlow, you can wrap tensors in tf.Variable() to keep their values in the graph between session runs. Just make sure to mark them as non-trainable, because the optimizers tune all trainable variables by default.
data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
cell = tf.nn.rnn_cell.GRUCell(256)
# Non-trainable variable holding the state; zero_state() provides the initial value.
state = tf.Variable(cell.zero_state(batch_size, tf.float32), trainable=False)
output, new_state = tf.nn.dynamic_rnn(cell, data, initial_state=state)
# Write the final state back into the variable before the output is returned.
with tf.control_dependencies([state.assign(new_state)]):
    output = tf.identity(output)
sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(output, {data: ...})
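To make the idea behind the snippet concrete, here is a minimal framework-free sketch in NumPy (a hypothetical vanilla RNN cell, not the GRU above): the hidden state is created once and carried from one batch into the next, which mirrors what the state.assign(new_state) trick achieves.

```python
import numpy as np

def step(state, x, W, U):
    # One vanilla RNN step: h' = tanh(W h + U x)
    return np.tanh(W @ state + U @ x)

hidden = 4
W = 0.5 * np.eye(hidden)        # recurrent weights (toy values)
U = 0.1 * np.ones((hidden, 1))  # input weights (toy values)

# Created once and never reset, like the non-trainable tf.Variable.
state = np.zeros(hidden)

for batch in [np.ones((3, 1)), np.ones((3, 1))]:
    for x in batch:
        state = step(state, x, W, U)
    # state now carries over into the next batch, mirroring state.assign(new_state)
```

If instead you reset state to zeros at the top of the loop body, you would get the default behavior described above, where every batch starts from zero vectors.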
I haven't tested this code, but it should point you in the right direction. There is also tf.nn.state_saving_rnn(), to which you can provide a state saver object, but I haven't used it yet.