1. LSTM
import tensorflow as tf
import tensorflow.contrib as contrib
from tensorflow.python.ops import array_ops
class lstm(object):
    """Single-layer unidirectional LSTM encoder (TensorFlow 1.x graph mode).

    The graph is built in ``__init__`` and the selected representation is
    exposed as ``self.out``.

    Args:
        in_data: Tensor passed as ``inputs`` to ``tf.nn.dynamic_rnn``. Since
            ``time_major`` is left at its default (False), it is presumably
            batch-major ``(batch, time, features)`` — TODO confirm with callers.
        hidden_dim: Number of units in the LSTM cell.
        batch_seqlen: Optional per-example valid lengths, forwarded as
            ``sequence_length`` to ``tf.nn.dynamic_rnn``.
        flag: Which representation to expose as ``self.out``:
            'all_ht'   -> all per-step outputs, ``(batch, time, hidden_dim)``
            'first_ht' -> output at the first time step
            'last_ht'  -> output at each sequence's last valid step
            'concat'   -> 'first_ht' and 'last_ht' concatenated on axis 1

    Raises:
        ValueError: if ``flag`` is not one of the values listed above.
    """

    _FLAGS = ('all_ht', 'first_ht', 'last_ht', 'concat')

    def __init__(self, in_data, hidden_dim, batch_seqlen=None, flag='concat'):
        # Validate before building any graph ops so a bad flag fails fast
        # instead of leaving the instance without ``self.out``.
        if flag not in self._FLAGS:
            raise ValueError(
                "flag must be one of {}, got {!r}".format(self._FLAGS, flag))
        self.in_data = in_data
        self.hidden_dim = hidden_dim
        self.batch_seqlen = batch_seqlen
        self.flag = flag

        lstm_cell = contrib.rnn.LSTMCell(self.hidden_dim)
        out, state = tf.nn.dynamic_rnn(
            cell=lstm_cell,
            inputs=self.in_data,
            sequence_length=self.batch_seqlen,
            dtype=tf.float32)

        if flag == 'all_ht':
            self.out = out
        elif flag == 'first_ht':
            self.out = out[:, 0, :]
        elif flag == 'last_ht':
            # With sequence_length set, dynamic_rnn zero-fills outputs past
            # each sequence's end, so out[:, -1, :] would be zeros for padded
            # rows. The final state's h carries the last *valid* output and
            # equals out[:, -1, :] when batch_seqlen is None.
            self.out = state.h
        else:  # 'concat'
            self.out = tf.concat([out[:, 0, :], state.h], 1)
2. Bi-LSTM
import tensorflow as tf
import tensorflow.contrib as contrib
from tensorflow.python.ops import array_ops
from tensorflow.python.framework import dtypes
class bilstm(object):
    """Single-layer bidirectional LSTM encoder (TensorFlow 1.x graph mode).

    The graph is built in ``__init__`` and the selected representation is
    exposed as ``self.out``.

    Args:
        in_data: Tensor passed as ``inputs`` to
            ``tf.nn.bidirectional_dynamic_rnn``. Since ``time_major`` is left
            at its default (False), it is presumably batch-major
            ``(batch, time, features)`` — TODO confirm with callers.
        hidden_dim: Number of units in each of the forward/backward cells.
        batch_seqlen: Optional per-example valid lengths, forwarded as
            ``sequence_length``.
        flag: Which representation to expose as ``self.out``:
            'all_ht'   -> forward/backward outputs concatenated per step,
                          ``(batch, time, 2 * hidden_dim)``
            'first_ht' -> the concatenated outputs at time step 0
            'last_ht'  -> forward and backward final hidden states concatenated
            'concat'   -> 'first_ht' and 'last_ht' concatenated on axis 1

    Raises:
        ValueError: if ``flag`` is not one of the values listed above.
    """

    _FLAGS = ('all_ht', 'first_ht', 'last_ht', 'concat')

    def __init__(self, in_data, hidden_dim, batch_seqlen=None, flag='concat'):
        # Validate before building any graph ops so a bad flag fails fast
        # instead of leaving the instance without ``self.out``.
        if flag not in self._FLAGS:
            raise ValueError(
                "flag must be one of {}, got {!r}".format(self._FLAGS, flag))
        self.in_data = in_data
        self.hidden_dim = hidden_dim
        self.batch_seqlen = batch_seqlen
        self.flag = flag

        lstm_cell_fw = contrib.rnn.LSTMCell(self.hidden_dim)
        lstm_cell_bw = contrib.rnn.LSTMCell(self.hidden_dim)
        out, state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=lstm_cell_fw,
            cell_bw=lstm_cell_bw,
            inputs=self.in_data,
            # Keyword was previously misspelled 'sequence_lenth', which
            # raised TypeError on every construction.
            sequence_length=self.batch_seqlen,
            dtype=tf.float32)

        # out is (output_fw, output_bw); join them on the feature axis.
        bi_out = tf.concat(out, 2)

        if flag == 'all_ht':
            self.out = bi_out
        elif flag == 'first_ht':
            self.out = bi_out[:, 0, :]
        else:
            # state is (state_fw, state_bw); .h of each LSTM state tuple is
            # that direction's final hidden state, which (unlike the last
            # output slice) is not zero-filled for padded sequences.
            final_ht = tf.concat([state[0].h, state[1].h], 1)
            if flag == 'last_ht':
                self.out = final_ht
            else:  # 'concat'
                self.out = tf.concat([bi_out[:, 0, :], final_ht], 1)