# Provides three basic types of recurrent neural network (SimpleRNN, LSTM and GRU),
# along with a training-history plot and a prediction plot.
# Expects the standard X_train, y_train, X_test, y_test arrays as input.
from keras.models import Sequential
from keras.layers import SimpleRNN, LSTM, GRU, Dense
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.models import load_model
class RNN:
    """Single-output recurrent regressor (SimpleRNN, LSTM or GRU) with plot helpers.

    The model is a Keras ``Sequential`` stack of one recurrent layer followed by
    ``Dense(1)``, compiled with mean-squared-error loss and the Adam optimizer.

    Args:
        RNN_TYPE: ``'LSTM'`` or ``'GRU'``; any other value builds a ``SimpleRNN``.
        time_step: Length of each input sequence. When omitted, falls back to the
            module-level ``PAST_DAY`` constant (defined outside this class).
        epoch: Maximum number of training epochs (early stopping may end sooner).
        batch_size: Mini-batch size for training; ``fit`` also passes it to
            ``RNNCell`` as the recurrent layer's unit count.
    """

    def __init__(self, RNN_TYPE, time_step=None, epoch=30, batch_size=256):
        # Only read the global when no explicit value is supplied, so the class
        # can be used without PAST_DAY being defined.
        self.time_step = PAST_DAY if time_step is None else time_step
        self.epoch = epoch
        self.batch_size = batch_size
        self.RNN_TYPE = RNN_TYPE

    def fit(self, X_train, y_train, X_test, y_test):
        """Build a fresh model and train it; the Keras History goes to ``self.history``.

        Args:
            X_train: 3-D training inputs; ``X_train.shape[2]`` is used as the
                per-step feature count and the second axis must equal
                ``self.time_step`` (the layer's ``input_shape``).
            y_train: Training targets.
            X_test: Inputs used as Keras validation data.
            y_test: Targets used as Keras validation data.
        """
        self.model = self.RNNCell(self.batch_size, self.time_step, X_train.shape[2],
                                  rnn_type=self.RNN_TYPE)
        # Stop training once val_loss has not improved for 5 consecutive epochs.
        # NOTE(review): the test split doubles as validation data, so early
        # stopping is selecting on the test set; consider a separate split.
        callback = EarlyStopping(monitor="val_loss", patience=5, verbose=1, mode="auto")
        self.history = self.model.fit(X_train, y_train,
                                      epochs=self.epoch, batch_size=self.batch_size,
                                      validation_data=(X_test, y_test),
                                      callbacks=[callback])

    def RNNCell(self, batch_size, time_step, input_dimension, rnn_type):
        """Build and compile a one-recurrent-layer regression model.

        Args:
            batch_size: Used as the number of units of the recurrent layer
                (the name is kept for backward compatibility).
            time_step: Sequence length of each sample.
            input_dimension: Number of features per time step.
            rnn_type: ``'LSTM'`` or ``'GRU'``; anything else selects ``SimpleRNN``.

        Returns:
            The compiled Keras ``Sequential`` model.
        """
        print("###########")
        print(rnn_type)
        print("###########")
        # Unrecognised types deliberately fall back to SimpleRNN.
        layer_cls = {'LSTM': LSTM, 'GRU': GRU}.get(rnn_type, SimpleRNN)
        model = Sequential()
        model.add(layer_cls(batch_size, input_shape=(time_step, input_dimension)))
        model.add(Dense(1))
        model.compile(loss='mean_squared_error', optimizer='adam')
        model.summary()
        return model

    def train_history_plot(self, metric):
        """Plot the train and validation curves of ``metric`` from the last ``fit``.

        Args:
            metric: A key recorded in ``self.history.history``, e.g. ``'loss'``.
                ``'acc'`` also matches the ``'accuracy'`` key used by newer Keras.

        Raises:
            ValueError: If ``metric`` or its ``'val_'`` counterpart was not recorded.
        """
        history = self.history.history
        # Newer Keras releases log accuracy as 'accuracy' rather than 'acc'.
        if metric == 'acc' and 'acc' not in history and 'accuracy' in history:
            metric = 'accuracy'
        val_metric = 'val_' + metric
        # Validate before plotting instead of falling through with undefined names.
        missing = [key for key in (metric, val_metric) if key not in history]
        if missing:
            raise ValueError(
                f"metric {metric!r} is not available: missing history keys {missing}; "
                f"recorded keys are {sorted(history)}")

        if metric == 'loss':
            ylabel = 'Loss'
        elif metric in ('acc', 'accuracy'):
            ylabel = 'Accuracy'
        else:
            ylabel = metric.replace('_', ' ').title()
        n_epochs = len(history[metric])

        # A dedicated figure keeps later plots (e.g. predict_plot) off these axes.
        plt.figure(figsize=(12, 6))
        plt.plot(history[metric])
        plt.plot(history[val_metric])
        plt.title('Model ' + ylabel)
        plt.ylabel(ylabel)
        plt.xlabel('Epoch')
        # Early stopping can end before self.epoch, so size the ticks to the run.
        plt.xticks(range(0, n_epochs, 5))
        plt.legend(['Train', 'Validation'], loc='best')

    def predict_plot(self, X, y, index=None):
        """Predict on ``X``, print the evaluation loss and plot predictions vs ``y``.

        Args:
            X: Model inputs.
            y: True targets aligned with ``X``.
            index: Date-like x-axis labels (items must support ``strftime``).
                Defaults to ``data_test.index``, a module-level frame defined
                outside this class, to preserve the original behaviour.

        Returns:
            The result of ``self.model.evaluate(X, y)`` — the MSE loss for a
            model built by ``RNNCell``, which is compiled without extra metrics.
        """
        predict = self.model.predict(X)
        loss = self.model.evaluate(X, y)
        print('MSE Loss : ', loss)
        if index is None:
            index = data_test.index
        # pandas Index.format() is deprecated; build the date labels directly.
        labels = [stamp.strftime('%Y-%m-%d') for stamp in index]

        plt.figure(figsize=(12, 6))
        plt.plot(predict, 'r--', label='Predict')
        plt.plot(y, label='Real')
        plt.title("Price Prediction")
        plt.legend(loc='best')
        # Label every 30th point so the date axis stays readable.
        # NOTE(review): assumes len(index) == len(y); verify against the caller.
        axes = plt.gca()
        axes.set_xticks(range(0, len(labels), 30))
        axes.set_xticklabels(labels[::30])
        plt.gcf().autofmt_xdate()
        plt.autoscale(axis='x')
        plt.show()
        return loss
# Train an LSTM regressor, plot its loss curves, then score and plot test predictions.
# The model class defined in this file is named RNN (there is no RNNModel).
lstm = RNN('LSTM')
lstm.fit(X_train, y_train, X_test, y_test)
lstm.train_history_plot('loss')
lstm_loss = lstm.predict_plot(X_test, y_test)