In this notebook we present a simple Variational Autoencoder (VAE), trained on MNIST.
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import backend as K
from tensorflow.keras import metrics
from tensorflow.keras import utils
from tensorflow.keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
The variational autoencoder will allow us to generate digits similar to those in the MNIST dataset.
# train the VAE on MNIST digits
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
# one-hot encode the labels (prepared here, but not used by this unconditional model)
y_train = utils.to_categorical(y_train)
y_test = utils.to_categorical(y_test)
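As a quick sanity check (a small addition, not in the original notebook), the preprocessed arrays should have the following shapes:
print(x_train.shape, y_train.shape)  # (60000, 784) (60000, 10)
print(x_test.shape, y_test.shape)    # (10000, 784) (10000, 10)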
The model
Sampling function for the Variational Autoencoder.
def sampling(args):
    # reparameterization trick: z = mu + sigma * epsilon, with epsilon ~ N(0, I)
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),
                              mean=0., stddev=1.)
    # z_log_var is log(sigma^2), so sigma = exp(z_log_var / 2)
    return z_mean + K.exp(z_log_var / 2) * epsilon
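The function implements the reparameterization trick: instead of sampling $z$ directly from $\mathcal{N}(\mu, \sigma^2)$, which would block gradient flow, we sample $\epsilon \sim \mathcal{N}(0, I)$ and compute

$$z = \mu + e^{\frac{1}{2}\log \sigma^2} \cdot \epsilon$$

so that the sampling noise is external to the network's parameters and gradients can flow through $\mu$ and $\log \sigma^2$.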
Main dimensions for the model (a simple stack of dense layers).
input_dim = 784
latent_dim = 16
intermediate_dim_1 = 128
intermediate_dim_2 = 32
We start with the encoder. It takes the flattened image as input.
It returns the latent encoding (z_mean) and a (log-)variance (z_log_var) for each latent variable.
x = layers.Input(shape=(input_dim,))
h1 = layers.Dense(intermediate_dim_1, activation='swish')(x)
h2 = layers.Dense(intermediate_dim_2, activation='swish')(h1)
z_mean = layers.Dense(latent_dim)(h2)
z_log_var = layers.Dense(latent_dim)(h2)
encoder = Model(x, [z_mean, z_log_var])
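A minimal sketch (not in the original notebook) to verify the encoder's output shapes, assuming the cells above have been run:
mu, log_var = encoder(x_train[:4])
print(mu.shape, log_var.shape)  # expected: (4, 16) (4, 16)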
Now we define the decoder. It takes as input a vector in the latent space and returns the image of a digit.
z = layers.Input(shape=(latent_dim,))
dec_mid_1 = layers.Dense(intermediate_dim_2, activation='swish')(z)
dec_mid_2 = layers.Dense(intermediate_dim_1, activation='swish')(dec_mid_1)
x_hat = layers.Dense(input_dim, activation='sigmoid')(dec_mid_2)
decoder = Model(inputs=z, outputs=[x_hat])
We build the VAE by composing the encoder and the decoder. Between them, however, we need to insert the sampling operation.
To wrap the sampling function into a layer we use Keras's Lambda layer.
x = layers.Input(shape=(input_dim,))
z_mean, z_log_var = encoder(x)
z = layers.Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])
x_hat = decoder(z)
vae = Model(x, x_hat)
The VAE loss function is the sum of the reconstruction error (MSE or BCE) and the KL divergence, which acts as a regularizer of the latent space.
beta = 1.  # balancing factor between reconstruction and KL terms
rec_loss = input_dim * metrics.binary_crossentropy(x, x_hat)
kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
vae_loss = K.mean(rec_loss + beta * kl_loss)
vae.add_loss(vae_loss)
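For a diagonal Gaussian posterior $\mathcal{N}(\mu, \sigma^2)$ and a standard normal prior, the KL divergence has the closed form

$$D_{KL} = -\frac{1}{2} \sum_{i=1}^{d} \left( 1 + \log \sigma_i^2 - \mu_i^2 - \sigma_i^2 \right),$$

which is exactly what kl_loss computes above, with z_log_var playing the role of $\log \sigma^2$.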
Some training hyperparameters. The factor beta defined above balances the log-likelihood and the KL divergence in the loss function.
batch_size = 100
epochs = 50
We are ready to compile. There is no need to specify the loss function, since we already added it to the model with add_loss.
vae.compile(optimizer='adam')
Train for a sufficient number of epochs. Generation is a more complex task than classification.
#vae.load_weights("cvae256_8.h5")
vae.fit(x_train, None, shuffle=True, epochs=epochs, batch_size=batch_size, validation_data=(x_test, None))
vae.save_weights("cvae256_8.h5")
Epoch 1/50 600/600 [==============================] - 4s 5ms/step - loss: 188.4745 - val_loss: 150.0237
Epoch 2/50 600/600 [==============================] - 3s 5ms/step - loss: 141.7901 - val_loss: 134.6203
Epoch 3/50 600/600 [==============================] - 3s 5ms/step - loss: 131.5237 - val_loss: 127.0711
...
Epoch 49/50 600/600 [==============================] - 3s 5ms/step - loss: 105.3508 - val_loss: 105.2149
Epoch 50/50 600/600 [==============================] - 3s 5ms/step - loss: 105.2493 - val_loss: 105.2078
Let us plot some examples.
def plot(images):
    n = images.shape[0]
    plt.figure(figsize=(2*n, 2))
    for i in range(n):
        # display each image in a row of subplots
        ax = plt.subplot(1, n, i + 1)
        plt.imshow(images[i].reshape(28, 28))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
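Before sampling from the prior, a quick reconstruction check (a small sketch, not in the original notebook): the trained VAE should roughly reproduce test digits.
plot(x_test[:10])               # original test digits
plot(vae.predict(x_test[:10]))  # their reconstructions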
z_sample = np.random.normal(size=(10,latent_dim))
#print(z_sample.shape)
generated = decoder.predict(z_sample)
#print(generated.shape)
plot(generated)
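As a further experiment, a minimal sketch of latent-space interpolation between two test digits (the variable names are illustrative, not from the original notebook):
# encode two test images and linearly interpolate between their latent means
z_a, _ = encoder(x_test[:1])
z_b, _ = encoder(x_test[1:2])
steps = np.linspace(0., 1., 10).reshape(-1, 1)
z_interp = (1. - steps) * z_a.numpy() + steps * z_b.numpy()
plot(decoder.predict(z_interp))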