Dans cette partie, nous cherchons à générer des images de chiffres manuscrits à l'aide d'un réseau DCGAN ( Deep Convolutional Generative Adversarial Network).
Le code est écrit à l'aide de l'API séquentielle Keras avec une boucle d'entraînement tf.GradientTape, une méthode d'enregistrement d'opérations pour la différenciation automatique [2]. L'API Keras, quant à elle, permet de créer et d'entraîner des modèles de Deep Learning facilement et rapidement [3].
Nous allons essayer d'illustrer le processus d'apprentissage des GANs avec le jeu de données MNIST. Il s'agit d'une base de données de chiffres écrits à la main. L'animation suivante montre une série d'images produites par le générateur lors de son apprentissage sur 50 époques. Les images commencent par un bruit aléatoire et ressemblent de plus en plus à des chiffres écrits à la main au fil du temps.
Figure 1 – Animation montrant une série d'images produites par le générateur à l'aide du jeu de données MNIST [1]
import tensorflow as tf
import tensorflow_datasets as tfds
import glob
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
from IPython import display
Nous utiliserons le jeu de données MNIST pour former le générateur et le discriminateur. Le générateur générera des chiffres manuscrits ressemblant aux données MNIST.
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Regroupe et mélange les données
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
Le générateur et le discriminateur sont définis à l'aide de l' API séquentielle Keras [3].
Le générateur utilise tf.keras.layers.Conv2DTranspose (suréchantillonage) pour produire une image à partir d'une graine (bruit aléatoire) [4].
Nous commençons avec un calque Dense prenant cette graine en entrée, puis nous suréchantillonnons plusieurs fois jusqu'à atteindre une taille d'image raisonnable, soit de 28x28x1. A noter que nous activons tf.keras.layers.LeakyReLU [5] pour chaque couche, à l'exception de la couche de sortie qui utilise tanh.
def make_generator_model():
model = tf.keras.Sequential()
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Reshape((7, 7, 256)))
assert model.output_shape == (None, 7, 7, 256) # Note: None représente la taille de groupement
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
assert model.output_shape == (None, 7, 7, 128)
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
assert model.output_shape == (None, 14, 14, 64)
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
assert model.output_shape == (None, 28, 28, 1)
return model
Nous utilisons alors le générateur (pas encore formé) pour créer une image.
generator = make_generator_model()
noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
Figure 1 – Première sortie du générateur non entraîné
Le discriminateur est un classificateur d'image basé sur CNN [6].
def make_discriminator_model():
model = tf.keras.Sequential()
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
input_shape=[28, 28, 1]))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))
model.add(layers.Flatten())
model.add(layers.Dense(1))
return model
Nous utilisons alors le discriminateur (pas encore formé) pour classer les images générées comme réelles ou fausses. Le modèle sera formé pour générer des valeurs positives pour les images réelles et des valeurs négatives pour les images factices.
discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print(decision)
Nous obtenons deux résultats:
tf.Tensor([[-0.00090226]], shape=(1, 1), dtype=float32)
tf.Tensor([[0.00230907]], shape=(1, 1), dtype=float32)
Nous constatons bien que le discriminateur n'est pas entraîné étant donné que deux décisions faites l'une après l'autre donnent des résultats différents concernant la même image. En effet, le premier résultat donne une image factice, tandis que le deuxième donne une image authentique.
Définissons les fonctions de perte et les optimiseurs pour les deux modèles.
# Cette méthode renvoie une fonction permettant de calculer le cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
Cette méthode quantifie la capacité du discriminateur à distinguer les vraies images des fausses. Il compare les prédictions du discriminateur sur des images réelles à un tableau de 1, et les prédictions du discriminateur sur de fausses images (générées) à un tableau de 0.
def discriminator_loss(real_output, fake_output):
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
total_loss = real_loss + fake_loss
return total_loss
La perte du générateur quantifie à quel point il a réussi à tromper le discriminateur. Intuitivement, si le générateur fonctionne bien, le discriminateur classera les fausses images comme vraies (ou 1). Ici, nous comparons les décisions du discriminateur sur les images générées à un tableau de 1.
def generator_loss(fake_output):
return cross_entropy(tf.ones_like(fake_output), fake_output)
Le discriminateur et les optimiseurs de générateur sont différents puisque nous entraînons deux réseaux séparément.
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16
# Nous réutiliserons cette graine pour simplifier l'usage
seed = tf.random.normal([num_examples_to_generate, noise_dim])
La boucle d'apprentissage commence avec le générateur recevant une graine aléatoire en entrée. Cette graine est utilisée pour produire une image. Le discriminateur est ensuite utilisé pour classer les images réelles (tirées de l'ensemble d'apprentissage) et les fausses images (produites par le générateur). La perte est calculée pour chacun de ces modèles, et les gradients sont utilisés pour mettre à jour le générateur et le discriminateur.
# A noter l'utilisation de `tf.function`
# Cette annotation provoque la "compilation" de la fonction
@tf.function
def train_step(images):
noise = tf.random.normal([BATCH_SIZE, noise_dim])
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise, training=True)
real_output = discriminator(images, training=True)
fake_output = discriminator(generated_images, training=True)
gen_loss = generator_loss(fake_output)
disc_loss = discriminator_loss(real_output, fake_output)
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def generate_and_save_images(model, epoch, test_input):
# Notice `training` is set to False.
# This is so all layers run in inference mode (batchnorm).
predictions = model(test_input, training=False)
fig = plt.figure(figsize=(4, 4))
for i in range(predictions.shape[0]):
plt.subplot(4, 4, i+1)
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
plt.axis('off')
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
plt.show()
def train(dataset, epochs):
for epoch in range(epochs):
start = time.time()
for image_batch in dataset:
train_step(image_batch)
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
# Generate after the final epoch
display.clear_output(wait=True)
generate_and_save_images(generator,
epochs,
seed)
Il ne nous reste plus qu'à appeler la méthode train() définie ci-dessus pour entraîner simultanément le générateur et le discriminateur. A savoir que la formation des GAN peut être délicate. Il est important que le générateur et le discriminateur ne se surpassent pas (par exemple, qu'ils s'entraînent à un rythme similaire).
Au début de la formation, les images générées ressemblent à du bruit aléatoire. Au fur et à mesure que la formation progresse, les chiffres générés sembleront de plus en plus réels. Après environ 30 époques, ils ressemblent aux chiffres MNIST.
train(train_dataset, EPOCHS)
Lors des premières générations, nous constatons que les images ne ressemblent pas aux chiffres MNIST. Nous voyons bien que le GAN commence à apprendre le dataset.
Figure 2 – Epoch 0
Passé une dizaine de générations, nous voyons des formes distinctes apparaître. Nous sommes tout de même loin de distinguer des chiffres au sein de nos générations.
Figure 3 – Epoch 10
Après cela, au fil des générations, nous commençons à distinguer des chiffres. Il faut alors faire un peu plus de quarante générations pour distinguer des "7" ou bien des "9".
Figure 4 – Epoch 40
Finalement, nous voyons bien les chiffres MNIST passé une centaine de générations. Cependant, nous remarquons que tous les chiffres n'apparaissent pas, laissant présager un mode collapse.
Figure 5 – Epoch 100
Après cela, les chiffres se clarifient davantage, que ce soit après cinq cent ou mille générations, le générateur se précise sans cesse. Nous pouvons le voir notamment avec le "5" que le générateur parvient à produire clairement avec le millième échantillon.
Figure 6 – Epoch 500
Figure 7 – Epoch 1000
Notre programme est concluant et permet de réaliser certains chiffres du dataset de façon nets et visibles. Il permet également d'afficher et d'enregistrer chaque image pour avoir un retour direct sur son efficacité. Le réseau s'améliore bien comme prévu, que ce soit du côté du discriminateur ou du générateur. Tout ceci est notamment possible grâce aux bibliothèques tensorflow et matplotlib mettant à disposition de nombreuses fonctionnalités et documentations.
Nous voyons lors des générations que les capacités de calcul déployés ne sont pas suffisantes. En effet, pour chaque génération, mon ordianteur personnel met approximativement 400 secondes, soit plus de 6 minutes. De plus, quelques générations prennent plus de temps, ce qui peut limiter la génération avec notre type de GAN. Je remercie notamment Moiroud Elliott, sans qui je ne serais pas parvenu à produire autant de générations.
L'objectif ici était de réaliser notre propre GAN capable de s'améliorer lui-même, chose que nous avons pu constater lors de ses diverses générations. Notre modèle est évidemment perfectible et laisse supposer certaines erreurs d'entraînement, comme le mode collapse. Cependant, le but n'était pas de réaliser le GAN le plus optimisé possible, seulement en concevoir un simple exemplaire.