kennyhsieh1111
12/28/2019 - 3:59 PM

Recurrent Neural Network Class

Recurrent Neural Network Class

Provides the three basic recurrent neural network types (SimpleRNN, LSTM, GRU), plus a training-history plot and a prediction plot.

Input

Standard train/test splits: X_train, y_train, X_test, y_test. The X arrays are 3-D, shaped (samples, time_step, features).

Result

  • Train History Plot
  • Prediction Plot
from keras.models import Sequential
from keras.layers import SimpleRNN, LSTM, GRU, Dense
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.models import load_model
class RNN:
    """Build, train and plot a single-layer Keras recurrent regressor.

    The network is one recurrent layer (LSTM, GRU or SimpleRNN) followed by a
    ``Dense(1)`` output, compiled with mean-squared-error loss and Adam.

    Relies on names defined outside this class: ``PAST_DAY`` (default
    ``time_step``), ``plt`` (matplotlib.pyplot) and, when ``predict_plot`` is
    called without ``index``, a module-level ``data_test`` object.
    """

    def __init__(self, RNN_TYPE, time_step=None, epoch=30, batch_size=256, units=None):
        """Store hyper-parameters; the Keras model itself is built in ``fit``.

        Args:
            RNN_TYPE: ``'LSTM'`` or ``'GRU'``; any other value selects ``SimpleRNN``.
            time_step: length of each input window. Defaults to the global
                ``PAST_DAY`` (the original behaviour).
            epoch: maximum number of training epochs (early stopping may end
                training sooner).
            batch_size: mini-batch size passed to ``model.fit``.
            units: width of the recurrent layer. Defaults to ``batch_size``,
                which reproduces the original behaviour of using the batch size
                as the layer width.
        """
        # Only read the global PAST_DAY when no explicit window length is given.
        self.time_step = PAST_DAY if time_step is None else time_step
        self.epoch = epoch
        self.batch_size = batch_size
        self.units = batch_size if units is None else units
        self.RNN_TYPE = RNN_TYPE

    def fit(self, X_train, y_train, X_test, y_test):
        """Build a fresh model and train it, validating on the test split.

        ``X_train`` is expected to be 3-D ``(samples, time_step, features)``;
        its third dimension sets the model's input width. Training stops early
        once ``val_loss`` has not improved for 5 epochs. The Keras history is
        kept on ``self.history`` for ``train_history_plot``.
        """
        # NOTE(review): the test split doubles as validation data, so early
        # stopping is tuned on it — confirm this is intended.
        self.model = self.RNNCell(self.units, self.time_step, X_train.shape[2], rnn_type=self.RNN_TYPE)
        early_stop = EarlyStopping(monitor="val_loss", patience=5, verbose=1, mode="auto")
        self.history = self.model.fit(
            X_train, y_train,
            epochs=self.epoch,
            batch_size=self.batch_size,
            validation_data=(X_test, y_test),
            callbacks=[early_stop],
        )

    def RNNCell(self, batch_size, time_step, input_dimension, rnn_type):
        """Build and compile a one-recurrent-layer regression model.

        Args:
            batch_size: number of units in the recurrent layer. The name is
                historical; ``fit`` passes ``self.units`` here.
            time_step: length of each input sequence.
            input_dimension: number of features per time step.
            rnn_type: ``'LSTM'`` or ``'GRU'``; anything else falls back to
                ``SimpleRNN``.

        Returns:
            A compiled ``Sequential`` model (MSE loss, Adam optimizer).
        """
        print("###########")
        print(rnn_type)
        print("###########")

        layer_cls = {'LSTM': LSTM, 'GRU': GRU}.get(rnn_type, SimpleRNN)
        model = Sequential()
        model.add(layer_cls(batch_size, input_shape=(time_step, input_dimension)))
        model.add(Dense(1))
        model.compile(loss='mean_squared_error', optimizer='adam')
        model.summary()
        return model

    @staticmethod
    def _metric_labels(metric):
        """Return ``(validation_key, y_label, title)`` for a history metric name."""
        known = {
            'acc': ('Accuracy', 'Model Accuracy'),
            'accuracy': ('Accuracy', 'Model Accuracy'),
            'loss': ('Loss', 'Model Loss'),
        }
        ylabel, title = known.get(metric, (metric, 'Model ' + metric))
        return 'val_' + metric, ylabel, title

    def train_history_plot(self, metric):
        """Plot per-epoch training and validation curves of ``metric``, then show the figure.

        Must be called after ``fit``.

        Args:
            metric: a key of ``self.history.history`` such as ``'loss'`` or
                ``'acc'``. The matching ``'val_' + metric`` series is plotted too.

        Raises:
            KeyError: if either series was not recorded during training.
        """
        val_metric, ylabel, title = self._metric_labels(metric)
        # Look both series up before drawing anything, so a bad metric fails cleanly.
        train_values = self.history.history[metric]
        val_values = self.history.history[val_metric]

        plt.rcParams["figure.figsize"] = (12, 6)
        plt.plot(train_values)
        plt.plot(val_values)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.xlabel('Epoch')
        # Early stopping can end training before self.epoch, so tick only the recorded epochs.
        plt.xticks(range(0, len(train_values) + 1, 5))
        plt.legend(['Train', 'Validation'], loc='best')
        # Render now, like predict_plot does; otherwise a following plot in
        # the same run would be drawn onto these axes.
        plt.show()

    def predict_plot(self, X, y, index=None):
        """Plot predictions against true targets and return the evaluation loss.

        Args:
            X: model inputs, shaped like the training inputs.
            y: true targets aligned with ``X``.
            index: date-like values used for the x tick labels, one per sample.
                Defaults to ``data_test.index`` from the module-level
                ``data_test``, which the original implementation always used.

        Returns:
            The value of ``self.model.evaluate(X, y)``. For a model built by
            ``RNNCell`` (no extra metrics), this is the MSE loss.
        """
        predict = self.model.predict(X)
        loss = self.model.evaluate(X, y)
        print('MSE Loss : ', loss)

        plt.plot(predict, 'r--', label='Predict')
        plt.plot(y, label='Real')
        plt.title("Price Prediction")
        plt.legend(loc='best')

        if index is None:
            index = data_test.index  # legacy fallback on the module-level global
        # NOTE(review): assumes the index lines up with X/y sample-for-sample — confirm.
        # Equivalent to the deprecated pandas Index.format(formatter=...).
        plot_index = [stamp.strftime('%Y-%m-%d') for stamp in index]
        axes = plt.gca()
        axes.set_xticks(range(0, len(plot_index), 30))
        axes.set_xticklabels(plot_index[::30])

        plt.gcf().autofmt_xdate()
        plt.autoscale(axis='x')
        plt.show()

        return loss
        
# Example driver: train an LSTM on the prepared splits, then plot its loss
# curve and its test-set predictions. X_train/y_train/X_test/y_test are
# expected to be defined earlier (e.g. in the notebook).
if __name__ == "__main__":
    lstm = RNN('LSTM')  # the class is named RNN; `RNNModel` does not exist
    lstm.fit(X_train, y_train, X_test, y_test)
    lstm.train_history_plot('loss')
    lstm_loss = lstm.predict_plot(X_test, y_test)