tensorflow实现rnn
import tensorflow as tf
import logging
import time
import numpy as np
class RNNModel(object):
    """Stacked vanilla-RNN next-symbol model built as a TensorFlow 1.x graph.

    Keeping graph construction inside a class makes it easy to build separate
    training and sampling instances of the same model and to restore them from
    checkpoints. The graph is built in ``__init__``; ``train`` runs one epoch
    and ``predict`` greedily generates text.
    """

    def __init__(self, is_test, hidden_size, rnn_layers, batch_size, seq_length, vocab_size, embedding_size,
                 learning_rate, max_grad_norm):
        """Build the model graph.

        Args:
            is_test: ``'true'`` (or ``True``) builds a sampling graph with
                batch_size = seq_length = 1 and no training op. ``'false'``
                (or ``False``) builds the graph including ``self.train_op``.
                Any other value builds the graph at the given sizes without a
                training op.
            hidden_size: number of units in each RNN layer.
            rnn_layers: number of stacked RNN layers (at least one is built).
            batch_size: sequences per batch (ignored in sampling mode).
            seq_length: unrolled time steps (ignored in sampling mode).
            vocab_size: number of distinct symbol ids.
            embedding_size: dimensionality of the symbol embeddings.
            learning_rate: Adam learning rate.
            max_grad_norm: global-norm threshold for gradient clipping.
        """
        testing = is_test is True or is_test == 'true'
        training = is_test is False or is_test == 'false'

        self.hidden_size = hidden_size
        self.rnn_layers = rnn_layers
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.learning_rate = tf.constant(learning_rate)
        if testing:
            # Sampling feeds a single symbol of a single sequence per run.
            self.batch_size = 1
            self.seq_length = 1

        # One independent BasicRNNCell per layer; at least one layer is always built.
        cells = [tf.contrib.rnn.BasicRNNCell(hidden_size) for _ in range(max(rnn_layers, 1))]
        multi_cell = tf.contrib.rnn.MultiRNNCell(cells)
        self.zero_state = multi_cell.zero_state(self.batch_size, tf.float32)
        # NOTE(review): create_tuple_placeholders_with_default is defined elsewhere;
        # presumably it builds feedable placeholders that default to zero_state,
        # so the recurrent state can be carried across session.run calls.
        self.initial_state = create_tuple_placeholders_with_default(self.zero_state, shape=multi_cell.state_size)

        self.input_data = tf.placeholder(tf.int64, [self.batch_size, self.seq_length], name='inputs')
        self.targets = tf.placeholder(tf.int64, [self.batch_size, self.seq_length], name='targets')

        self.embedding = tf.get_variable('embedding', [vocab_size, embedding_size])
        # [batch, seq] ids -> [batch, seq, embedding_size] vectors.
        inputs = tf.nn.embedding_lookup(self.embedding, self.input_data)
        # static_rnn expects a Python list of seq_length tensors, each [batch, embedding_size].
        sliced_inputs = tf.unstack(inputs, num=self.seq_length, axis=1)
        outputs, final_state = tf.contrib.rnn.static_rnn(multi_cell, sliced_inputs, initial_state=self.initial_state)
        self.final_state = final_state

        # Concatenating the per-step [batch, hidden] outputs along axis 1 and
        # reshaping yields rows in batch-major order (b0t0, b0t1, ..., b1t0, ...),
        # which matches a row-major flatten of the [batch, seq] targets.
        flat_outputs = tf.reshape(tf.concat(axis=1, values=outputs), [-1, hidden_size])
        flat_targets = tf.reshape(self.targets, [-1])
        softmax_w = tf.get_variable("softmax_w", [hidden_size, vocab_size])
        softmax_b = tf.get_variable("softmax_b", [vocab_size])
        self.logits = tf.matmul(flat_outputs, softmax_w) + softmax_b
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=flat_targets)
        mean_loss = tf.reduce_mean(loss)

        # Running average of mean_loss over every evaluation of average_loss.
        # The accumulators start at 0 so the first reading equals the first batch
        # loss. They are non-trainable so they are not handed to the optimizer.
        count = tf.Variable(0.0, name='count', trainable=False)
        sum_mean_loss = tf.Variable(0.0, name='sum_mean_loss', trainable=False)
        update_loss_monitor = tf.group(sum_mean_loss.assign(sum_mean_loss + mean_loss), count.assign(count + 1),
                                       name='update_loss_monitor')
        with tf.control_dependencies([update_loss_monitor]):
            # read_value() creates the reads inside this control-dependency
            # context, so the division observes the freshly updated accumulators.
            self.average_loss = sum_mean_loss.read_value() / count.read_value()

        self.global_step = tf.get_variable('global_step', [],
                                           initializer=tf.constant_initializer(0.0),
                                           trainable=False)
        if training:
            tvars = tf.trainable_variables()
            grads, _ = tf.clip_by_global_norm(tf.gradients(mean_loss, tvars), max_grad_norm)
            optimizer = tf.train.AdamOptimizer(self.learning_rate)
            self.train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step)

    def train(self, session, train_size, train_batches):
        """Run one training epoch, logging the running average loss each step.

        Args:
            session: tf.Session holding this graph with variables initialized.
            train_size: total number of symbols in the training data; the number
                of steps is train_size / (batch_size * seq_length), rounded up.
            train_batches: iterator yielding one batch per step. Each batch is
                split as ``data[:-1]`` (inputs) and ``data[1:]`` (targets) and
                then transposed; presumably it holds seq_length + 1 rows of
                batch_size ids each — TODO confirm against the batch generator.
        """
        symbols_per_step = self.batch_size * self.seq_length
        # Ceiling division: a trailing partial batch still counts as a step.
        epoch_size = -(-train_size // symbols_per_step)
        state = session.run(self.zero_state)
        start_time = time.time()
        for step in range(epoch_size):
            # The next() builtin works for Python 3 generators, which lack .next().
            data = next(train_batches)
            inputs = np.array(data[:-1]).transpose()
            targets = np.array(data[1:]).transpose()
            feed_dict = {self.input_data: inputs, self.targets: targets,
                         self.initial_state: state}
            # The returned final_state is fed back as the next step's initial state.
            average_loss, state, _ = session.run(
                [self.average_loss, self.final_state, self.train_op], feed_dict)
            # Guard against a zero interval on coarse-resolution clocks.
            elapsed = max(time.time() - start_time, 1e-6)
            logging.info("average loss: %.3f, speed: %.0f chars per sec",
                         average_loss, (step + 1) * symbols_per_step / elapsed)

    def predict(self, session, start_text, length, vocab_index_dict, index_vocab_dict):
        """Greedily extend ``start_text`` by ``length`` symbols, print and return it.

        Each run feeds a single ``[[id]]`` input, so this expects a model built
        with ``is_test='true'`` (batch_size = seq_length = 1).

        Args:
            session: tf.Session holding this (restored) graph.
            start_text: non-empty seed sequence; every element must be a key
                of vocab_index_dict.
            length: number of symbols to generate.
            vocab_index_dict: symbol -> id mapping.
            index_vocab_dict: id -> symbol mapping.

        Returns:
            The seed followed by the generated symbols joined into one string
            (the same text that is printed).

        Raises:
            ValueError: if start_text is empty (there is no symbol to start from).
        """
        if not start_text:
            raise ValueError('start_text must contain at least one symbol')
        state = session.run(self.zero_state)
        # Warm up the hidden state on every seed symbol except the last; the last
        # one becomes the first input of the generation loop.
        for char in start_text[:-1]:
            x = np.array([[vocab_index_dict[char]]])
            state = session.run(self.final_state, {self.input_data: x, self.initial_state: state})
        seq = list(start_text)
        x = np.array([[vocab_index_dict[start_text[-1]]]])
        for _ in range(length):
            state, logits = session.run([self.final_state, self.logits],
                                        {self.input_data: x, self.initial_state: state})
            # Greedy decoding: softmax is monotonic, so the argmax of the logits is
            # the argmax of the probabilities; normalizing them is unnecessary.
            sample = int(np.argmax(logits[0]))
            seq.append(index_vocab_dict[sample])
            x = np.array([[sample]])
        text = ''.join(seq)
        print(text)
        return text