Summary: Recurrent Neural Networks, RNN, LSTM, Long Short-term Memory, seq2seq

Implementation

generate_data

generate_data(training_size=10)
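
The body of generate_data is not shown on this page. As a rough illustration, a generator of this kind could look like the sketch below. It assumes the function returns two parallel lists of strings (questions such as '5+340' and answers such as '345') with operands of up to three digits. The name `generate_data_sketch` and the `digits` and `seed` parameters are hypothetical and not part of the module.

```python
import random

def generate_data_sketch(training_size=10, digits=3, seed=None):
    """Hypothetical stand-in for generate_data: random addition problems.

    Returns two parallel lists of strings: questions such as '5+340'
    and their answers such as '345'.
    """
    rng = random.Random(seed)

    def random_operand():
        # Pick the number of digits first so short operands are as
        # common as long ones, then draw each digit independently.
        n_digits = rng.randint(1, digits)
        return int(''.join(rng.choice('0123456789') for _ in range(n_digits)))

    questions, answers = [], []
    for _ in range(training_size):
        a, b = random_operand(), random_operand()
        questions.append(f'{a}+{b}')
        answers.append(str(a + b))
    return questions, answers

pairs, ans = generate_data_sketch(training_size=5, seed=0)
for q, a in zip(pairs, ans):
    print(q, '=', a)
```

Passing a fixed `seed` makes the sketch reproducible, which is convenient when comparing training runs.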

encode

encode(questions, answers, alphabet)

decode

decode(seq, alphabet, calc_argmax=True)
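
The implementations of encode and decode are not listed here either. Judging from the array shapes and the left-padded strings printed further down, they behave like a one-hot encoder and its inverse. The sketch below is one way to get that behavior. The names `encode_sketch` and `decode_sketch`, the `digits` parameter, and the left-padding with spaces are assumptions made for illustration.

```python
import numpy as np

def encode_sketch(questions, answers, alphabet, digits=3):
    """Hypothetical one-hot encoder; left-pads with spaces to a fixed length."""
    char_to_idx = {c: i for i, c in enumerate(alphabet)}
    q_len, a_len = 2 * digits + 1, digits + 1
    x = np.zeros((len(questions), q_len, len(alphabet)), dtype=bool)
    y = np.zeros((len(answers), a_len, len(alphabet)), dtype=bool)
    for i, (q, a) in enumerate(zip(questions, answers)):
        # One row per character position, one column per alphabet symbol
        for t, ch in enumerate(q.rjust(q_len)):
            x[i, t, char_to_idx[ch]] = True
        for t, ch in enumerate(a.rjust(a_len)):
            y[i, t, char_to_idx[ch]] = True
    return x, y

def decode_sketch(seq, alphabet, calc_argmax=True):
    """Hypothetical decoder: one-hot rows (or class indices) back to a string."""
    if calc_argmax:
        seq = np.argmax(seq, axis=-1)
    return ''.join(alphabet[int(i)] for i in seq)

alphabet = list('0123456789+ ')
x, y = encode_sketch(['5+340'], ['345'], alphabet)
print(x.shape, y.shape)
print(repr(decode_sketch(x[0], alphabet)), repr(decode_sketch(y[0], alphabet)))
```

With this sketch, the single example encodes to shapes (1, 7, 12) and (1, 4, 12), and decoding returns the padded strings '  5+340' and ' 345'.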

Let's generate some data

DIGITS = 3  # maximum number of digits per operand
MAXLEN = DIGITS + DIGITS + 1  # longest question: two operands plus the '+' sign
n_training_examples = 1000
print('Generating data...', end=' ')
pairs,ans = generate_data(n_training_examples)
print('done!')
print('Size of Training set: ' , len(pairs))
alphabet = list('0123456789+ ')
x,y = encode(pairs, ans, alphabet)
100%|██████████| 1000/1000 [00:00<00:00, 11661.15it/s]
Generating data... done!
Size of Training set:  1000
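
To see where the sequence lengths come from, it helps to write out the longest possible question and answer for DIGITS = 3. The short check below is only an illustration; the variable names other than DIGITS are made up for it.

```python
DIGITS = 3
largest = 10 ** DIGITS - 1                  # largest operand with DIGITS digits
longest_question = f'{largest}+{largest}'   # two operands and the '+' sign
longest_answer = str(largest + largest)     # the sum can carry into one extra digit
print(longest_question, len(longest_question))
print(longest_answer, len(longest_answer))
```

The longest question has DIGITS + DIGITS + 1 characters and the longest answer has DIGITS + 1, which matches MAXLEN and the RepeatVector(DIGITS + 1) layer used in the model.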

Split the data

We split the data into training and testing sets, holding out 10% of the examples for testing.

from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1)
print('x_train shape = ' , x_train.shape)
print('y_train shape = ', y_train.shape)
x_train shape =  (900, 7, 12)
y_train shape =  (900, 4, 12)
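
The idea behind the split can be sketched with plain NumPy: shuffle the example indices once, then use the same order for inputs and targets so that pairs stay aligned. This is a hand-rolled illustration, not scikit-learn's implementation. The helper `split_sketch`, its `seed` parameter, and the zero-filled placeholder arrays are assumptions.

```python
import numpy as np

def split_sketch(x, y, test_size=0.1, seed=0):
    """Hypothetical shuffle-and-split helper; x and y are shuffled together."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(x))
    n_test = int(round(len(x) * test_size))
    test_idx, train_idx = order[:n_test], order[n_test:]
    return x[train_idx], x[test_idx], y[train_idx], y[test_idx]

# Placeholder arrays with the same shapes as the encoded data
x = np.zeros((1000, 7, 12), dtype=bool)
y = np.zeros((1000, 4, 12), dtype=bool)
x_train, x_test, y_train, y_test = split_sketch(x, y, test_size=0.1)
print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)
```

With 1000 examples and test_size=0.1, 900 examples go to training and 100 to testing, the same sizes reported above.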

Build the Model

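Now it's time to build a sequence-to-sequence RNN with LSTM cells. An LSTM encoder compresses the question into a single vector, RepeatVector feeds that vector to every step of an LSTM decoder, and a time-distributed softmax layer predicts one character per output position.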

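from tensorflow.keras.models import Sequential  # assumes tf.keras (TensorFlow 2.x)
from tensorflow.keras.layers import LSTM, RepeatVector, TimeDistributed, Dense

print('Build model...')
model = Sequential()
# Encoder: read the one-hot question and keep only the final hidden state
model.add(LSTM(128, input_shape=(MAXLEN, len(alphabet))))
# Repeat that state once per output character
model.add(RepeatVector(DIGITS + 1))
# Decoder: emit a hidden state at every output position
model.add(LSTM(128, return_sequences=True))
# Softmax over the alphabet at each output position
model.add(TimeDistributed(Dense(len(alphabet), activation='softmax')))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()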
Build model...
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
lstm (LSTM)                  (None, 128)               72192     
_________________________________________________________________
repeat_vector (RepeatVector) (None, 4, 128)            0         
_________________________________________________________________
lstm_1 (LSTM)                (None, 4, 128)            131584    
_________________________________________________________________
time_distributed (TimeDistri (None, 4, 12)             1548      
=================================================================
Total params: 205,324
Trainable params: 205,324
Non-trainable params: 0
_________________________________________________________________
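
The parameter counts in the summary can be checked by hand. An LSTM layer holds four weight sets (input, forget, and output gates plus the cell candidate), each with an input kernel, a recurrent kernel, and a bias. A dense layer holds one kernel and one bias. The helper names below are made up for this check.

```python
def lstm_params(input_dim, units):
    # Four weight sets, each with an input kernel (input_dim x units),
    # a recurrent kernel (units x units), and a bias (units)
    return 4 * (units * (input_dim + units) + units)

def dense_params(input_dim, units):
    # One kernel (input_dim x units) and one bias (units)
    return input_dim * units + units

alphabet_size = 12                        # len('0123456789+ ')
encoder = lstm_params(alphabet_size, 128)
decoder = lstm_params(128, 128)
output = dense_params(128, alphabet_size)
print(encoder, decoder, output, encoder + decoder + output)
```

This reproduces the 72,192, 131,584, and 1,548 parameters listed per layer, and the total of 205,324. RepeatVector only copies its input, so it adds none.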

Train the Model

After building and compiling the model, we train it one epoch at a time and, after each epoch, print ten random test predictions.

import numpy as np

EPOCHS = 2
BATCH_SIZE = 32

# ANSI escape codes for coloring terminal output
class colors:
    ok = '\033[92m'    # green
    fail = '\033[91m'  # red
    close = '\033[0m'  # reset

for epoch in range(1, EPOCHS + 1):
    print('Iteration ', epoch)
    model.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=1, validation_data=(x_test, y_test), verbose=1)
    # Select 10 samples from test set and visualize errors
    for i in range(10):
        index = np.random.randint(0, len(x_test))
        q = x_test[np.array([index])]
        ans = y_test[np.array([index])]
        # Class index with the highest probability at each output position
        preds = np.argmax(model.predict(q),axis=-1)
        question = decode(q[0],alphabet)
        actual = decode(ans[0],alphabet)
        guessed = decode(preds[0], alphabet, calc_argmax=False)
        print('Q:', question, end=' ')
        print('  Actual:', actual, end=' ')
        if actual == guessed:
            print(colors.ok + '  ☑' + colors.close, end=' ')
        else:
            print(colors.fail + '  ☒' + colors.close, end=' ')
        print('Guessed:', guessed)
Iteration  1
Train on 900 samples, validate on 100 samples
900/900 [==============================] - 1s 1ms/sample - loss: 1.5905 - accuracy: 0.3958 - val_loss: 1.6040 - val_accuracy: 0.3925
Q:   5+340   Actual:  345  Guessed:  155
Q:   5+721   Actual:  726  Guessed:  155
Q:    27+7   Actual:   34  Guessed:   22
Q:  91+665   Actual:  756  Guessed:  155
Q:   388+9   Actual:  397  Guessed:  155
Q:  941+23   Actual:  964  Guessed:  155
Q:   17+67   Actual:   84  Guessed:   55
Q:  91+665   Actual:  756  Guessed:  155
Q: 204+890   Actual: 1094  Guessed:  555
Q:     5+9   Actual:   14  Guessed:    2
Iteration  2
Train on 900 samples, validate on 100 samples
900/900 [==============================] - 1s 1ms/sample - loss: 1.5800 - accuracy: 0.4006 - val_loss: 1.6091 - val_accuracy: 0.4025
Q:   7+471   Actual:  478  Guessed:  449
Q:  672+52   Actual:  724  Guessed:  499
Q:    51+1   Actual:   52  Guessed:   34
Q: 757+459   Actual: 1216  Guessed: 1444
Q:    51+1   Actual:   52  Guessed:   34
Q:   5+721   Actual:  726  Guessed:  149
Q:     9+2   Actual:   11  Guessed:   34
Q:     5+9   Actual:   14  Guessed:   14
Q:  547+73   Actual:  620  Guessed:  499
Q:     4+4   Actual:    8  Guessed:   14
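
The accuracy Keras reports for this model is measured per character position, because the softmax layer is applied at every output step. Correctly predicted padding spaces therefore count toward it. An answer is only right when every character matches, so an exact-match score is a stricter measure. The sketch below computes both from class indices, using two of the predictions printed above as input. The function names are made up for illustration.

```python
import numpy as np

def char_accuracy(y_true_idx, y_pred_idx):
    """Fraction of individual character positions predicted correctly."""
    return float(np.mean(y_true_idx == y_pred_idx))

def exact_match_accuracy(y_true_idx, y_pred_idx):
    """Fraction of samples where every position matches."""
    return float(np.mean(np.all(y_true_idx == y_pred_idx, axis=-1)))

# Class indices into '0123456789+ ' (index 11 is the space)
y_true = np.array([[11, 11, 1, 4],    # '  14'
                   [11, 7, 2, 6]])    # ' 726'
y_pred = np.array([[11, 11, 1, 4],    # '  14'
                   [11, 1, 4, 9]])    # ' 149'
print(char_accuracy(y_true, y_pred), exact_match_accuracy(y_true, y_pred))
```

For these two samples the per-character accuracy is 0.625 while the exact-match accuracy is 0.5. The second answer is wrong, but it still earns credit for its leading space.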