Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


lstm - How to pass the variables of LSTMCell as input instead of letting it create them in TensorFlow?

When I create a tf.contrib.rnn.LSTMCell, it creates its trainable kernel and bias variables internally.

How the code looks now:

cell_fw = tf.contrib.rnn.LSTMCell(hidden_size_char,
                        state_is_tuple=True)

What I want it to look like:

kernel = tf.get_variable(...)
bias = tf.get_variable(...)
cell_fw = tf.contrib.rnn.LSTMCell(kernel, bias, hidden_size_char,
                        state_is_tuple=True)

I want to create those variables myself and pass them to the LSTMCell constructor when instantiating it.

Is there an easy way to do this? I looked at the class source code, but it sits within a complex class hierarchy.



1 Answer


I subclassed LSTMCell and overrode its __init__ and build methods so that they accept externally created variables. If a kernel and bias are passed to __init__, build no longer creates them itself and uses the given variables instead.

There might be cleaner ways to do it though.

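# Written against TensorFlow 1.x, where tf.contrib is available.
import tensorflow as tf
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import partitioned_variables

_BIAS_VARIABLE_NAME = "bias"
_WEIGHTS_VARIABLE_NAME = "kernel"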

class MyLSTMCell(tf.contrib.rnn.LSTMCell):
    def __init__(self, num_units,
                 use_peepholes=False, cell_clip=None,
                 initializer=None, num_proj=None, proj_clip=None,
                 num_unit_shards=None, num_proj_shards=None,
                 forget_bias=1.0, state_is_tuple=True,
                 activation=None, reuse=None, name=None, var_given=False, kernel=None, bias=None):

        super(MyLSTMCell, self).__init__(num_units,
                 use_peepholes=use_peepholes, cell_clip=cell_clip,
                 initializer=initializer, num_proj=num_proj, proj_clip=proj_clip,
                 num_unit_shards=num_unit_shards, num_proj_shards=num_proj_shards,
                 forget_bias=forget_bias, state_is_tuple=state_is_tuple,
                 activation=activation, reuse=reuse, name=name)

        self.var_given = var_given
        if self.var_given:
            self._kernel = kernel
            self._bias = bias


    def build(self, inputs_shape):
        if inputs_shape[1].value is None:
            raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                             % inputs_shape)

        input_depth = inputs_shape[1].value
        h_depth = self._num_units if self._num_proj is None else self._num_proj
        maybe_partitioner = (
            partitioned_variables.fixed_size_partitioner(self._num_unit_shards)
            if self._num_unit_shards is not None
            else None)
        if self.var_given:
            # self._kernel and self._bias were already set in __init__
            pass
        else:
            self._kernel = self.add_variable(
                _WEIGHTS_VARIABLE_NAME,
                shape=[input_depth + h_depth, 4 * self._num_units],
                initializer=self._initializer,
                partitioner=maybe_partitioner)
            self._bias = self.add_variable(
                _BIAS_VARIABLE_NAME,
                shape=[4 * self._num_units],
                initializer=init_ops.zeros_initializer(dtype=self.dtype))
        if self._use_peepholes:
            self._w_f_diag = self.add_variable("w_f_diag", shape=[self._num_units],
                                               initializer=self._initializer)
            self._w_i_diag = self.add_variable("w_i_diag", shape=[self._num_units],
                                               initializer=self._initializer)
            self._w_o_diag = self.add_variable("w_o_diag", shape=[self._num_units],
                                               initializer=self._initializer)

        if self._num_proj is not None:
            maybe_proj_partitioner = (
                partitioned_variables.fixed_size_partitioner(self._num_proj_shards)
                if self._num_proj_shards is not None
                else None)
            self._proj_kernel = self.add_variable(
                "projection/%s" % _WEIGHTS_VARIABLE_NAME,
                shape=[self._num_units, self._num_proj],
                initializer=self._initializer,
                partitioner=maybe_proj_partitioner)

        self.built = True

With this subclass, the calling code looks like this:

kernel = tf.get_variable(...)
bias = tf.get_variable(...)
lstm_fw = MyLSTMCell(....., var_given=True, kernel=kernel, bias=bias)
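
The variables passed in have to match the shapes that build would otherwise request through add_variable: [input_depth + h_depth, 4 * num_units] for the kernel and [4 * num_units] for the bias, where h_depth is num_proj if a projection is used and num_units otherwise. As a minimal sketch of that arithmetic in plain Python, without TensorFlow (the helper names lstm_variable_shapes and check_given_shapes are made up for illustration and are not part of any library):

```python
def lstm_variable_shapes(input_depth, num_units, num_proj=None):
    """Return (kernel_shape, bias_shape) mirroring the add_variable calls in build."""
    # build uses the projection size as the recurrent depth when num_proj is set.
    h_depth = num_units if num_proj is None else num_proj
    kernel_shape = [input_depth + h_depth, 4 * num_units]
    bias_shape = [4 * num_units]
    return kernel_shape, bias_shape


def check_given_shapes(kernel_shape, bias_shape, input_depth, num_units,
                       num_proj=None):
    """Raise ValueError if externally created variables would not fit the cell."""
    expected_kernel, expected_bias = lstm_variable_shapes(
        input_depth, num_units, num_proj)
    if list(kernel_shape) != expected_kernel:
        raise ValueError("kernel has shape %s, expected %s"
                         % (list(kernel_shape), expected_kernel))
    if list(bias_shape) != expected_bias:
        raise ValueError("bias has shape %s, expected %s"
                         % (list(bias_shape), expected_bias))


# Example: 50-dimensional inputs, 100 hidden units, no projection.
kernel_shape, bias_shape = lstm_variable_shapes(input_depth=50, num_units=100)
print(kernel_shape, bias_shape)  # [150, 400] [400]
```

In TensorFlow 1.x these lists could then be passed as the shape argument of tf.get_variable when creating kernel and bias. Running a check like check_given_shapes before constructing the cell is one possible way to catch a mismatch early, since the subclass itself does not validate the variables it is given.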
