In this notebook we present a simple conditional Variational Autoencoder (CVAE), trained on MNIST.

In [52]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import backend as K
from tensorflow.keras import metrics
from tensorflow.keras import utils
from tensorflow.keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt

The conditional autoencoder allows us to generate specific digits in the MNIST range 0-9. The condition is passed as input to both the encoder and the decoder in categorical (one-hot) format.
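
For example, the one-hot encoding of the digit 3 over the 10 classes is a vector with a single 1 in position 3:

In [ ]:
utils.to_categorical(3, 10)
# expected output: array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], dtype=float32)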

In [53]:
# load and preprocess the MNIST digits
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
y_train = utils.to_categorical(y_train)
y_test = utils.to_categorical(y_test)
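
As a quick sanity check, the flattened images and the one-hot labels should have the following shapes:

In [ ]:
print(x_train.shape, y_train.shape)   # (60000, 784) (60000, 10)
print(x_test.shape, y_test.shape)     # (10000, 784) (10000, 10)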

The model

Sampling function for the Variational Autoencoder. It implements the reparameterization trick: instead of sampling z directly from the Gaussian N(z_mean, z_var), we sample epsilon from a standard normal N(0,1) and shift and scale it, so that the operation stays differentiable with respect to z_mean and z_log_var.
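
In formulas, the layer below computes

$$z = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$

where $\mu$ is z_mean and $\sigma$ = exp(z_log_var / 2); all the randomness is confined to $\epsilon$, so gradients can flow through $\mu$ and $\sigma$.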

In [54]:
def sampling(args):
    z_mean, z_log_var = args
    # epsilon is drawn from a standard normal, one vector per batch element
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0.,
                              stddev=1.)
    # shift and rescale: z = mean + std * epsilon, with std = exp(log_var / 2)
    return z_mean + K.exp(z_log_var / 2) * epsilon

Main dimensions for the model (a simple stack of dense layers).

In [55]:
input_dim = 784          # 28 x 28 flattened images
latent_dim = 8
intermediate_dim = 25

We start with the encoder. It takes two inputs: the image and the category.

It returns the latent encoding (z_mean) and a (log-)variance for each latent variable.

In [56]:
x = layers.Input(shape=(input_dim,))   # flattened image
y = layers.Input(shape=(10,))          # one-hot label
xy = layers.concatenate([x,y])         # condition the encoder on the label
h = layers.Dense(intermediate_dim, activation='relu')(xy)
z_mean = layers.Dense(latent_dim)(h)
z_log_var = layers.Dense(latent_dim)(h)

Now we sample around z_mean with the associated variance.

Note the use of the Lambda layer to wrap the sampling function into a Keras layer.

In [57]:
z = layers.Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

Now we turn to the decoder. We first define its layers, so that we can reuse them both in the VAE model and in the stand-alone generator.

In [58]:
decoder_mid = layers.Dense(intermediate_dim, activation='relu')
decoder_out = layers.Dense(input_dim, activation='sigmoid')

We decode the image starting from the latent representation z and its category y, which must be concatenated.

In [59]:
zy = layers.concatenate([z,y])
dec_mid = decoder_mid(zy)
x_hat = decoder_out(dec_mid)

vae = Model(inputs=[x,y], outputs=[x_hat])
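
Optionally, the architecture just assembled can be inspected with the standard Keras summary:

In [ ]:
vae.summary()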

Some hyperparameters. Gamma is used to balance the log-likelihood and the KL-divergence in the loss function.

In [91]:
batch_size = 100
epochs = 50
gamma = .5

The VAE loss function is just the sum of the reconstruction error (MSE or BCE) and the KL-divergence, which acts as a regularizer of the latent space.
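
For reference, the KL-divergence between the diagonal Gaussian $\mathcal{N}(\mu, \sigma^2)$ and the prior $\mathcal{N}(0, I)$ has the well-known closed form

$$D_{KL} = -\frac{1}{2}\sum_{i=1}^{d}\left(1 + \log\sigma_i^2 - \mu_i^2 - \sigma_i^2\right)$$

which is exactly the expression computed below, with z_log_var playing the role of $\log\sigma^2$.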

In [ ]:
rec_loss = input_dim * metrics.binary_crossentropy(x, x_hat)   # BCE summed over the 784 pixels
kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
vae_loss = K.mean(rec_loss + gamma*kl_loss)
vae.add_loss(vae_loss)

We are ready to compile. There is no need to specify the loss function, since we already added it to the model with add_loss.

In [62]:
vae.compile(optimizer='adam')

Train for a sufficient number of epochs. Generation is a more complex task than classification.

In [92]:
vae.fit([x_train,y_train], shuffle=True, epochs=epochs, batch_size=batch_size, validation_data=([x_test,y_test], None))

vae.save_weights("cvae256_8.h5")
Epoch 1/50
600/600 [==============================] - 2s 4ms/step - loss: 117.4831 - val_loss: 116.5517
Epoch 2/50
600/600 [==============================] - 2s 4ms/step - loss: 117.4093 - val_loss: 116.6035
Epoch 3/50
600/600 [==============================] - 2s 4ms/step - loss: 117.3337 - val_loss: 116.3547
Epoch 4/50
600/600 [==============================] - 2s 4ms/step - loss: 117.2388 - val_loss: 116.1454
Epoch 5/50
600/600 [==============================] - 2s 4ms/step - loss: 117.1626 - val_loss: 116.1069
Epoch 6/50
600/600 [==============================] - 2s 4ms/step - loss: 117.0258 - val_loss: 116.0364
Epoch 7/50
600/600 [==============================] - 2s 4ms/step - loss: 116.9395 - val_loss: 116.0056
Epoch 8/50
600/600 [==============================] - 2s 4ms/step - loss: 116.8417 - val_loss: 115.9822
Epoch 9/50
600/600 [==============================] - 2s 4ms/step - loss: 116.7859 - val_loss: 115.9235
Epoch 10/50
600/600 [==============================] - 2s 4ms/step - loss: 116.6798 - val_loss: 115.8738
Epoch 11/50
600/600 [==============================] - 2s 4ms/step - loss: 116.6165 - val_loss: 115.6923
Epoch 12/50
600/600 [==============================] - 2s 4ms/step - loss: 116.5539 - val_loss: 115.6227
Epoch 13/50
600/600 [==============================] - 2s 4ms/step - loss: 116.4546 - val_loss: 115.4645
Epoch 14/50
600/600 [==============================] - 2s 4ms/step - loss: 116.3838 - val_loss: 115.4126
Epoch 15/50
600/600 [==============================] - 2s 4ms/step - loss: 116.3074 - val_loss: 115.3833
Epoch 16/50
600/600 [==============================] - 2s 4ms/step - loss: 116.2289 - val_loss: 115.3078
Epoch 17/50
600/600 [==============================] - 2s 4ms/step - loss: 116.1605 - val_loss: 115.2800
Epoch 18/50
600/600 [==============================] - 2s 4ms/step - loss: 116.1253 - val_loss: 115.1793
Epoch 19/50
600/600 [==============================] - 2s 4ms/step - loss: 116.0497 - val_loss: 115.0415
Epoch 20/50
600/600 [==============================] - 2s 4ms/step - loss: 116.0013 - val_loss: 115.0856
Epoch 21/50
600/600 [==============================] - 2s 4ms/step - loss: 115.9594 - val_loss: 115.0095
Epoch 22/50
600/600 [==============================] - 2s 4ms/step - loss: 115.8630 - val_loss: 115.0317
Epoch 23/50
600/600 [==============================] - 2s 4ms/step - loss: 115.8311 - val_loss: 114.8825
Epoch 24/50
600/600 [==============================] - 2s 4ms/step - loss: 115.7783 - val_loss: 114.8832
Epoch 25/50
600/600 [==============================] - 2s 4ms/step - loss: 115.6958 - val_loss: 114.8055
Epoch 26/50
600/600 [==============================] - 2s 4ms/step - loss: 115.6792 - val_loss: 114.7332
Epoch 27/50
600/600 [==============================] - 2s 4ms/step - loss: 115.5967 - val_loss: 114.6834
Epoch 28/50
600/600 [==============================] - 2s 4ms/step - loss: 115.5696 - val_loss: 114.6365
Epoch 29/50
600/600 [==============================] - 2s 4ms/step - loss: 115.5632 - val_loss: 114.6210
Epoch 30/50
600/600 [==============================] - 2s 4ms/step - loss: 115.4929 - val_loss: 114.6344
Epoch 31/50
600/600 [==============================] - 2s 4ms/step - loss: 115.4326 - val_loss: 114.5848
Epoch 32/50
600/600 [==============================] - 2s 4ms/step - loss: 115.4046 - val_loss: 114.5109
Epoch 33/50
600/600 [==============================] - 2s 4ms/step - loss: 115.3857 - val_loss: 114.5302
Epoch 34/50
600/600 [==============================] - 2s 4ms/step - loss: 115.3479 - val_loss: 114.5327
Epoch 35/50
600/600 [==============================] - 2s 4ms/step - loss: 115.3165 - val_loss: 114.4069
Epoch 36/50
600/600 [==============================] - 2s 4ms/step - loss: 115.2919 - val_loss: 114.3649
Epoch 37/50
600/600 [==============================] - 2s 4ms/step - loss: 115.2542 - val_loss: 114.2841
Epoch 38/50
600/600 [==============================] - 2s 4ms/step - loss: 115.2031 - val_loss: 114.2183
Epoch 39/50
600/600 [==============================] - 2s 4ms/step - loss: 115.1898 - val_loss: 114.1535
Epoch 40/50
600/600 [==============================] - 2s 4ms/step - loss: 115.1469 - val_loss: 114.2817
Epoch 41/50
600/600 [==============================] - 2s 4ms/step - loss: 115.1144 - val_loss: 114.1720
Epoch 42/50
600/600 [==============================] - 2s 4ms/step - loss: 115.0736 - val_loss: 114.1761
Epoch 43/50
600/600 [==============================] - 2s 4ms/step - loss: 115.0692 - val_loss: 114.1343
Epoch 44/50
600/600 [==============================] - 2s 4ms/step - loss: 115.0446 - val_loss: 114.0661
Epoch 45/50
600/600 [==============================] - 2s 4ms/step - loss: 115.0027 - val_loss: 113.9973
Epoch 46/50
600/600 [==============================] - 2s 4ms/step - loss: 114.9952 - val_loss: 114.1289
Epoch 47/50
600/600 [==============================] - 2s 4ms/step - loss: 114.9515 - val_loss: 113.9786
Epoch 48/50
600/600 [==============================] - 2s 4ms/step - loss: 114.9670 - val_loss: 113.8860
Epoch 49/50
600/600 [==============================] - 2s 4ms/step - loss: 114.9258 - val_loss: 113.9838
Epoch 50/50
600/600 [==============================] - 2s 4ms/step - loss: 114.8821 - val_loss: 113.9098
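
In a later session, training can be skipped by rebuilding the same model and restoring the saved weights (a sketch; the filename must match the one used in save_weights above):

In [ ]:
vae.load_weights("cvae256_8.h5")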

Let us decode the full test set.

In [93]:
decoded_imgs = vae.predict([x_test,y_test])

The following function is used to inspect the quality of reconstructions (not particularly good, since the compression is strong).

In [96]:
def plot(n=10):
  plt.figure(figsize=(20, 4))
  for i in range(n):
    # display original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
  plt.show()
In [94]:
plot()
[figure: original test digits (top row) and their reconstructions (bottom row)]

Finally, we build a digit generator that can sample from the learned distribution.

In [75]:
noise = layers.Input(shape=(latent_dim,))
label = layers.Input(shape=(10,))
xy = layers.concatenate([noise,label])
dec_mid = decoder_mid(xy)   # reuse the decoder layers trained in the VAE
dec_out = decoder_out(dec_mid)
generator = Model([noise,label],dec_out)

And we can generate our samples.

In [ ]:
import time
# display an n x n grid of generated digits for a chosen label
n = 8  # figure with n x n digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))

while True:
  label = input("input digit to generate: ")
  label = int(label)
  if label < 0 or label > 9:
      # any out-of-range input terminates the loop
      break
  label = np.expand_dims(utils.to_categorical(label,10),axis=0)
  for i in range(0,n):
    for j in range (0,n):
        z_sample = np.expand_dims(np.random.normal(size=latent_dim),axis=0)
        x_decoded = generator.predict([z_sample,label])
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit
  plt.figure(figsize=(10, 10))
  plt.imshow(figure, cmap='Greys_r')
  plt.show()
  time.sleep(1)
input digit to generate: 7
[figure: 8x8 grid of generated 7s]
input digit to generate: 0
[figure: 8x8 grid of generated 0s]
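
The loop above requires interactive console input. As a non-interactive alternative, here is a minimal sketch of a helper (the name generate_grid is ours) that renders one grid of samples for a fixed digit, assuming generator, latent_dim, utils, np and plt are defined as above:

In [ ]:
def generate_grid(digit, n=8, digit_size=28):
    # one-hot encode the requested digit and fill an n x n grid of samples
    label = np.expand_dims(utils.to_categorical(digit, 10), axis=0)
    figure = np.zeros((digit_size * n, digit_size * n))
    for i in range(n):
        for j in range(n):
            z_sample = np.random.normal(size=(1, latent_dim))
            x_decoded = generator.predict([z_sample, label], verbose=0)
            figure[i*digit_size:(i+1)*digit_size,
                   j*digit_size:(j+1)*digit_size] = x_decoded[0].reshape(digit_size, digit_size)
    plt.figure(figsize=(10, 10))
    plt.imshow(figure, cmap='Greys_r')
    plt.show()

generate_grid(7)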