LearnedVector
8/14/2018 - 5:46 AM

lstm_tf_eager.py

class LSTMModel(tf.keras.Model):
    """Single-cell LSTM sequence classifier trained under eager execution.

    The model runs a ``BasicLSTMCell`` over every time step of the input,
    applies dropout to the output of the last time step and projects it to
    class logits with a dense layer.
    """

    def __init__(self, n_classes=None, n_units=64, dropout_rate=0.5):
        """Create the layers, the optimizer and the loss function.

        Args:
            n_classes: Width of the logits layer. When ``None`` (the default)
                it is derived from the module-level ``y_train`` array as the
                number of distinct labels, which is the original behaviour.
            n_units: Number of units in the LSTM cell.
            dropout_rate: Drop probability applied to the last LSTM output
                while training.
        """
        super().__init__()
        if n_classes is None:
            # Backward-compatible default: count the distinct labels in the
            # module-level ``y_train`` defined elsewhere in this file.
            n_classes = len(np.unique(y_train))
        self.rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(n_units)
        self.dropout = tf.layers.Dropout(rate=dropout_rate)
        self.dense1 = tf.layers.Dense(n_classes)
        self.optimizer = tf.train.AdamOptimizer()
        self.cross_entropy = tf.losses.sparse_softmax_cross_entropy

    def predict(self, input_data, train=False):
        """Return class logits for a batch of sequences.

        Args:
            input_data: Batch laid out as (batch_size, seq_len, n_features).
            train: When true, dropout is active on the last LSTM output.

        Returns:
            Logits tensor with one row per sample in the batch.

        Raises:
            ValueError: If ``input_data`` contains no time steps.
        """
        input_data = tf.convert_to_tensor(input_data)

        # Initial state. Its dtype follows the input so that the cell never
        # mixes dtypes (a hard-coded float64 breaks for float32 batches).
        batch_size = tf.shape(input_data)[0]
        state = self.rnn_cell.zero_state(batch_size, dtype=input_data.dtype)

        # Split along the time axis into a list of seq_len tensors, each of
        # shape (batch_size, n_features).
        steps = tf.unstack(input_data, axis=1)
        if not steps:
            raise ValueError("input_data must contain at least one time step")

        # Only the output of the last time step is used downstream, so keep
        # just the most recent output instead of stacking all of them and
        # gathering the last index afterwards.
        for step_input in steps:
            out, state = self.rnn_cell(step_input, state)

        out = self.dropout(out, training=train)
        logits = self.dense1(out)
        return logits

    def forward_pass(self, input_data):
        """Return training-mode logits (dropout enabled) for ``input_data``."""
        return self.predict(input_data, train=True)

    def backward_pass(self, loss, tape):
        """Apply one optimizer step for ``loss`` recorded on ``tape``.

        Only trainable variables are differentiated and updated.
        """
        train_vars = self.trainable_variables
        grads = tape.gradient(loss, train_vars)
        self.optimizer.apply_gradients(zip(grads, train_vars))

    def fit(self, X, Y, val_X=None, val_Y=None, epoch=1000, print_every=100):
        """Train on the full ``(X, Y)`` batch for ``epoch`` iterations.

        Args:
            X: Training inputs, passed unchanged to ``forward_pass``.
            Y: Integer class labels for ``X``.
            val_X: Optional validation inputs.
            val_Y: Optional validation labels; validation accuracy is
                reported only when both ``val_X`` and ``val_Y`` are given.
            epoch: Number of full-batch optimisation steps.
            print_every: Progress is printed every ``print_every`` steps.
        """
        for i in range(epoch):
            with tfe.GradientTape() as tape:
                preds = self.forward_pass(X)
                loss = self.cross_entropy(labels=Y, logits=preds)
            self.backward_pass(loss, tape)

            if (i + 1) % print_every != 0:
                continue

            loss_value = tf.cast(loss, tf.float64).numpy()
            # NOTE: ``preds`` comes from the training forward pass, so this
            # accuracy is measured with dropout active.
            # ``compute_accuracy`` is defined elsewhere in this file.
            acc = compute_accuracy(preds, Y)
            if val_X is not None and val_Y is not None:
                val_preds = self.predict(val_X)
                val_acc = compute_accuracy(val_preds, val_Y)
                # Training accuracy is rounded here as well, matching the
                # branch without validation data.
                print("epoch", i+1, "loss", round(loss_value, 5), "acc", round(acc, 2), "val_acc", round(val_acc, 2))
            else:
                print("epoch", i+1, "loss", round(loss_value, 5), "acc", round(acc, 2))

# Build the classifier and train it on (X, Y) for 1000 iterations, passing
# (X_t, Y_t) as the validation pair reported alongside training progress.
lstmmodel = LSTMModel()
lstmmodel.fit(
    X,
    Y,
    val_X=X_t,
    val_Y=Y_t,
    epoch=1000,
)