In this notebook we present a simple conditional VAE trained on MNIST.
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import backend as K
from tensorflow.keras import metrics
from tensorflow.keras import utils
from tensorflow.keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
The conditional autoencoder allows us to generate specific digits in the MNIST range 0-9. The condition is passed as an additional input to both the encoder and the decoder, in categorical (one-hot) format.
# load MNIST, scale pixels to [0, 1], flatten the images and one-hot encode the labels
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
y_train = utils.to_categorical(y_train)
y_test = utils.to_categorical(y_test)
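To make the categorical format concrete, the sketch below reproduces the one-hot encoding in plain NumPy and shows the shape of the vector obtained by placing an image and its condition side by side. The helper `one_hot` and the zero-filled `images` array are illustrative stand-ins, not part of the notebook.

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Return a (len(labels), num_classes) float32 one-hot matrix."""
    labels = np.asarray(labels, dtype=int)
    return np.eye(num_classes, dtype=np.float32)[labels]

# Hypothetical stand-in for a batch of three flattened 28x28 images.
images = np.zeros((3, 784), dtype=np.float32)
conditions = one_hot([7, 0, 3])

# Image and condition side by side: 784 + 10 = 794 features per example.
encoder_input = np.concatenate([images, conditions], axis=1)
print(conditions[0])
print(encoder_input.shape)
```

This is the same layout the encoder receives once the image and label inputs are concatenated.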
The model
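Sampling function for the Variational Autoencoder. It implements the reparameterization trick: instead of sampling z directly from the Gaussian N(z_mean, z_var), we sample epsilon from N(0,1) and compute z = z_mean + exp(z_log_var / 2) * epsilon, so that gradients can flow through z_mean and z_log_var.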
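def sampling(args):
    z_mean, z_log_var = args
    # one standard-normal draw per latent variable and per example in the batch
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0.,
                              stddev=1.)
    # exp(z_log_var / 2) is the standard deviation
    return z_mean + K.exp(z_log_var / 2) * epsilon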
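The sampling step can be checked in isolation. The following NumPy sketch mirrors the formula used in `sampling` (mean plus standard deviation times standard-normal noise) and verifies empirically that the samples have the requested mean and standard deviation. The function name `sample_latent` and the chosen values (mean 3, variance 0.25) are assumptions made only for this illustration.

```python
import numpy as np

def sample_latent(z_mean, z_log_var, rng):
    """Draw z = mean + std * eps with eps ~ N(0, I)."""
    eps = rng.standard_normal(z_mean.shape)
    return z_mean + np.exp(z_log_var / 2.0) * eps

rng = np.random.default_rng(0)
z_mean = np.full((100_000, 2), 3.0)
z_log_var = np.full((100_000, 2), np.log(0.25))  # variance 0.25, i.e. std 0.5
z = sample_latent(z_mean, z_log_var, rng)

# Empirical statistics should be close to mean 3.0 and std 0.5.
print(z.mean(axis=0), z.std(axis=0))
```

Since the randomness enters only through `eps`, the output is a deterministic, differentiable function of `z_mean` and `z_log_var`, which is what makes training by backpropagation possible.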
Main dimensions for the model (a simple stack of dense layers).
input_dim = 784
latent_dim = 8
intermediate_dim = 25
We start with the encoder. It takes two inputs: the image and the category.
It returns the latent encoding (z_mean) and a (log-)variance for each latent variable.
x = layers.Input(shape=(input_dim,))
y = layers.Input(shape=(10,))
xy = layers.concatenate([x,y])
h = layers.Dense(intermediate_dim, activation='relu')(xy)
z_mean = layers.Dense(latent_dim)(h)
z_log_var = layers.Dense(latent_dim)(h)
Now we sample around z_mean with the associated variance.
Note the use of the Lambda layer to wrap the sampling function as a Keras layer.
z = layers.Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])
Now we need to address the decoder. We first define its layers, in order to use them both in the vae model and in the stand-alone generator.
decoder_mid = layers.Dense(intermediate_dim, activation='relu')
decoder_out = layers.Dense(input_dim, activation='sigmoid')
We decode the image starting from the latent representation z and its category y, which must be concatenated before entering the decoder.
zy = layers.concatenate([z,y])
dec_mid = decoder_mid(zy)
x_hat = decoder_out(dec_mid)
vae = Model(inputs=[x,y], outputs=[x_hat])
Some hyperparameters. Gamma balances the log-likelihood (reconstruction) term against the KL-divergence in the loss function.
batch_size = 100
epochs = 50
gamma = .5
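The VAE loss function is the sum of the reconstruction error (mse or bce; here binary cross-entropy) and the KL-divergence weighted by gamma, which acts as a regularizer of the latent space. The KL term below is the closed form of the Kullback-Leibler divergence between the Gaussian N(z_mean, z_var) and the standard normal prior N(0,1).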
rec_loss = input_dim * metrics.binary_crossentropy(x, x_hat)
kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
vae_loss = K.mean(rec_loss + gamma*kl_loss)
vae.add_loss(vae_loss)
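As a sanity check on the two loss terms, here is a NumPy sketch of the same computation. It assumes that multiplying the per-pixel mean binary cross-entropy by `input_dim` is equivalent to summing the cross-entropy over pixels; the function names `bce_sum`, `kl_to_standard_normal` and `cvae_loss` are hypothetical and not part of the notebook.

```python
import numpy as np

def bce_sum(x, x_hat, eps=1e-7):
    """Binary cross-entropy summed over pixels, one value per example."""
    x_hat = np.clip(x_hat, eps, 1 - eps)  # avoid log(0)
    return -np.sum(x * np.log(x_hat) + (1 - x) * np.log(1 - x_hat), axis=-1)

def kl_to_standard_normal(z_mean, z_log_var):
    """Closed-form KL( N(z_mean, exp(z_log_var)) || N(0, I) ), one value per example."""
    return -0.5 * np.sum(1 + z_log_var - z_mean**2 - np.exp(z_log_var), axis=-1)

def cvae_loss(x, x_hat, z_mean, z_log_var, gamma=0.5):
    """Batch mean of reconstruction error plus gamma-weighted KL term."""
    return np.mean(bce_sum(x, x_hat) + gamma * kl_to_standard_normal(z_mean, z_log_var))

# The KL term vanishes when the posterior equals the prior ...
print(kl_to_standard_normal(np.zeros((1, 8)), np.zeros((1, 8))))
# ... and grows as z_mean moves away from zero.
print(kl_to_standard_normal(np.ones((1, 8)), np.zeros((1, 8))))
```

With unit variance, each latent variable contributes z_mean^2 / 2 to the KL term, so the penalty depends only on how far the encoding drifts from the origin.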
We are ready to compile. There is no need to specify the loss function, since we already added it to the model with add_loss.
vae.compile(optimizer='adam')
Train for a sufficient number of epochs. Generation is a more complex task than classification.
vae.fit([x_train,y_train], shuffle=True, epochs=epochs, batch_size=batch_size, validation_data=([x_test,y_test], None))
vae.save_weights("cvae256_8.h5")
Epoch 1/50
600/600 [==============================] - 2s 4ms/step - loss: 117.4831 - val_loss: 116.5517
Epoch 2/50
600/600 [==============================] - 2s 4ms/step - loss: 117.4093 - val_loss: 116.6035
Epoch 3/50
600/600 [==============================] - 2s 4ms/step - loss: 117.3337 - val_loss: 116.3547
Epoch 4/50
600/600 [==============================] - 2s 4ms/step - loss: 117.2388 - val_loss: 116.1454
Epoch 5/50
600/600 [==============================] - 2s 4ms/step - loss: 117.1626 - val_loss: 116.1069
Epoch 6/50
600/600 [==============================] - 2s 4ms/step - loss: 117.0258 - val_loss: 116.0364
Epoch 7/50
600/600 [==============================] - 2s 4ms/step - loss: 116.9395 - val_loss: 116.0056
Epoch 8/50
600/600 [==============================] - 2s 4ms/step - loss: 116.8417 - val_loss: 115.9822
Epoch 9/50
600/600 [==============================] - 2s 4ms/step - loss: 116.7859 - val_loss: 115.9235
Epoch 10/50
600/600 [==============================] - 2s 4ms/step - loss: 116.6798 - val_loss: 115.8738
Epoch 11/50
600/600 [==============================] - 2s 4ms/step - loss: 116.6165 - val_loss: 115.6923
Epoch 12/50
600/600 [==============================] - 2s 4ms/step - loss: 116.5539 - val_loss: 115.6227
Epoch 13/50
600/600 [==============================] - 2s 4ms/step - loss: 116.4546 - val_loss: 115.4645
Epoch 14/50
600/600 [==============================] - 2s 4ms/step - loss: 116.3838 - val_loss: 115.4126
Epoch 15/50
600/600 [==============================] - 2s 4ms/step - loss: 116.3074 - val_loss: 115.3833
Epoch 16/50
600/600 [==============================] - 2s 4ms/step - loss: 116.2289 - val_loss: 115.3078
Epoch 17/50
600/600 [==============================] - 2s 4ms/step - loss: 116.1605 - val_loss: 115.2800
Epoch 18/50
600/600 [==============================] - 2s 4ms/step - loss: 116.1253 - val_loss: 115.1793
Epoch 19/50
600/600 [==============================] - 2s 4ms/step - loss: 116.0497 - val_loss: 115.0415
Epoch 20/50
600/600 [==============================] - 2s 4ms/step - loss: 116.0013 - val_loss: 115.0856
Epoch 21/50
600/600 [==============================] - 2s 4ms/step - loss: 115.9594 - val_loss: 115.0095
Epoch 22/50
600/600 [==============================] - 2s 4ms/step - loss: 115.8630 - val_loss: 115.0317
Epoch 23/50
600/600 [==============================] - 2s 4ms/step - loss: 115.8311 - val_loss: 114.8825
Epoch 24/50
600/600 [==============================] - 2s 4ms/step - loss: 115.7783 - val_loss: 114.8832
Epoch 25/50
600/600 [==============================] - 2s 4ms/step - loss: 115.6958 - val_loss: 114.8055
Epoch 26/50
600/600 [==============================] - 2s 4ms/step - loss: 115.6792 - val_loss: 114.7332
Epoch 27/50
600/600 [==============================] - 2s 4ms/step - loss: 115.5967 - val_loss: 114.6834
Epoch 28/50
600/600 [==============================] - 2s 4ms/step - loss: 115.5696 - val_loss: 114.6365
Epoch 29/50
600/600 [==============================] - 2s 4ms/step - loss: 115.5632 - val_loss: 114.6210
Epoch 30/50
600/600 [==============================] - 2s 4ms/step - loss: 115.4929 - val_loss: 114.6344
Epoch 31/50
600/600 [==============================] - 2s 4ms/step - loss: 115.4326 - val_loss: 114.5848
Epoch 32/50
600/600 [==============================] - 2s 4ms/step - loss: 115.4046 - val_loss: 114.5109
Epoch 33/50
600/600 [==============================] - 2s 4ms/step - loss: 115.3857 - val_loss: 114.5302
Epoch 34/50
600/600 [==============================] - 2s 4ms/step - loss: 115.3479 - val_loss: 114.5327
Epoch 35/50
600/600 [==============================] - 2s 4ms/step - loss: 115.3165 - val_loss: 114.4069
Epoch 36/50
600/600 [==============================] - 2s 4ms/step - loss: 115.2919 - val_loss: 114.3649
Epoch 37/50
600/600 [==============================] - 2s 4ms/step - loss: 115.2542 - val_loss: 114.2841
Epoch 38/50
600/600 [==============================] - 2s 4ms/step - loss: 115.2031 - val_loss: 114.2183
Epoch 39/50
600/600 [==============================] - 2s 4ms/step - loss: 115.1898 - val_loss: 114.1535
Epoch 40/50
600/600 [==============================] - 2s 4ms/step - loss: 115.1469 - val_loss: 114.2817
Epoch 41/50
600/600 [==============================] - 2s 4ms/step - loss: 115.1144 - val_loss: 114.1720
Epoch 42/50
600/600 [==============================] - 2s 4ms/step - loss: 115.0736 - val_loss: 114.1761
Epoch 43/50
600/600 [==============================] - 2s 4ms/step - loss: 115.0692 - val_loss: 114.1343
Epoch 44/50
600/600 [==============================] - 2s 4ms/step - loss: 115.0446 - val_loss: 114.0661
Epoch 45/50
600/600 [==============================] - 2s 4ms/step - loss: 115.0027 - val_loss: 113.9973
Epoch 46/50
600/600 [==============================] - 2s 4ms/step - loss: 114.9952 - val_loss: 114.1289
Epoch 47/50
600/600 [==============================] - 2s 4ms/step - loss: 114.9515 - val_loss: 113.9786
Epoch 48/50
600/600 [==============================] - 2s 4ms/step - loss: 114.9670 - val_loss: 113.8860
Epoch 49/50
600/600 [==============================] - 2s 4ms/step - loss: 114.9258 - val_loss: 113.9838
Epoch 50/50
600/600 [==============================] - 2s 4ms/step - loss: 114.8821 - val_loss: 113.9098
Let us decode the full test set.
decoded_imgs = vae.predict([x_test,y_test])
The following function plots test images above their reconstructions, to check the quality of the reconstructions (not particularly good, since the compression is strong).
def plot(n=10):
    plt.figure(figsize=(20, 4))
    for i in range(n):
        # display original
        ax = plt.subplot(2, n, i + 1)
        plt.imshow(x_test[i].reshape(28, 28))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        # display reconstruction
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(decoded_imgs[i].reshape(28, 28))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
plot()
Finally, we build a digit generator that samples from the learned distribution. It reuses the trained decoder layers, fed with latent noise and the desired label.
noise = layers.Input(shape=(latent_dim,))
label = layers.Input(shape=(10,))
xy = layers.concatenate([noise,label])
dec_mid = decoder_mid(xy)
dec_out = decoder_out(dec_mid)
generator = Model([noise,label],dec_out)
And now we can generate our samples. Entering a value outside the range 0-9 ends the loop.
import time
# display an n x n grid of samples generated for the requested digit
n = 8  # figure with 8x8 digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
while True:
    label = input("input digit to generate: ")
    label = int(label)
    if label < 0 or label > 9:
        # any value outside 0-9 terminates the loop
        print(label)
        break
    label = np.expand_dims(utils.to_categorical(label, 10), axis=0)
    for i in range(0, n):
        for j in range(0, n):
            # draw a latent vector from the prior N(0, I)
            z_sample = np.expand_dims(np.random.normal(size=latent_dim), axis=0)
            x_decoded = generator.predict([z_sample, label])
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[i * digit_size: (i + 1) * digit_size,
                   j * digit_size: (j + 1) * digit_size] = digit
    plt.figure(figsize=(10, 10))
    plt.imshow(figure, cmap='Greys_r')
    plt.show()
    time.sleep(1)
input digit to generate: 7
input digit to generate: 0
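The generation loop calls the generator once per grid cell. A possible alternative, sketched below under stated assumptions, is to draw all n*n latent vectors at once, decode them in a single batch, and assemble the grid with a reshape. Here `fake_decoder` is a hypothetical stand-in for `generator.predict` so that the sketch runs without a trained model, and `tile_digits` is an illustrative helper that does not appear in the notebook.

```python
import numpy as np

def tile_digits(images, n, digit_size=28):
    """Arrange n*n flattened images into one (n*digit_size, n*digit_size) array."""
    grid = images.reshape(n, n, digit_size, digit_size)
    # Put the pixel-row axis next to the grid-row axis before merging them.
    return grid.transpose(0, 2, 1, 3).reshape(n * digit_size, n * digit_size)

def fake_decoder(z, labels):
    """Hypothetical stand-in for generator.predict: constant images per label."""
    return np.tile(labels.argmax(axis=1)[:, None] / 9.0, (1, 784))

n, latent_dim = 8, 8
rng = np.random.default_rng(0)
z_batch = rng.standard_normal((n * n, latent_dim))  # one latent vector per cell
labels = np.tile(np.eye(10)[7], (n * n, 1))         # same one-hot label for every cell
figure = tile_digits(fake_decoder(z_batch, labels), n)
print(figure.shape)
```

With the real model, `fake_decoder(z_batch, labels)` would be replaced by `generator.predict([z_batch, labels])`, which should need a single forward pass instead of 64.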