Skip to content

Latest commit

 

History

History
26 lines (21 loc) · 1.12 KB

README.md

File metadata and controls

26 lines (21 loc) · 1.12 KB

基于 peepholes LSTM 的ConvLSTM

由于contrib.rnn.ConvLSTMCell中对于ConvLSTMCell的实现本没有基于原作者的所引用的带有 "peepholes connection"的LSTM。因此,这里就照着葫芦画瓢,直接在原来的contrib.rnn.ConvLSTMCellcall()实现中上添加了peepholes这一步。

添加的代码为:

		kernel_shape = cell.shape.as_list()[-3:]
		w_ci = vs.get_variable(
			"w_ci", kernel_shape, inputs.dtype)
		w_cf = vs.get_variable(
			"w_cf", kernel_shape, inputs.dtype)
		w_co = vs.get_variable(
			"w_co", kernel_shape, inputs.dtype)

        new_cell = math_ops.sigmoid(forget_gate + self._forget_bias + w_cf * cell) * cell
        new_cell += math_ops.sigmoid(input_gate + w_ci * cell) * math_ops.tanh(new_input)
        output = math_ops.tanh(new_cell) * math_ops.sigmoid(output_gate + w_co * new_cell)

引用时,将 ConvLSTM中的BasicConvLSTM导入即可:

from ConvLSTM import BasicConvLSTM

用法同ConvLSTMCell一模一样!

循环神经网络系列(七)Tensorflow中ConvLSTMCell