Bienvenue dans ce guide pratique dédié à l'apprentissage des réseaux de neurones convolutifs (CNN) avec JAX ! Dans cet article, nous allons construire un modèle simple pour classer des images manuscrites issues du dataset MNIST. L'objectif principal est d'apprendre et de comprendre les bases des CNN ainsi que les fonctionnalités essentielles de JAX.
JAX est une bibliothèque Python idéale pour explorer le deep learning grâce à ses outils puissants comme le calcul automatique des gradients (`jax.grad`), la compilation rapide avec `@jax.jit`, et une syntaxe très proche de NumPy. Ces caractéristiques en font un excellent choix pour expérimenter et maîtriser les concepts fondamentaux.
Un CNN (réseau de neurones convolutifs) est un type de modèle spécialement conçu pour traiter des données structurées en grille, comme les images. Il est particulièrement efficace pour analyser des pixels et extraire des informations utiles à partir d'eux.
Pourquoi les CNN sont-ils si puissants pour ces tâches ? Explorons cela étape par étape :
Les CNN sont utilisés dans une grande variété d'applications modernes, notamment :
Pour bien comprendre comment fonctionne un CNN, il est important de connaître ses composants principaux :
La convolution est l'opération fondamentale d'un CNN, car elle permet d'extraire des caractéristiques visuelles importantes des images, comme des bords, des textures ou des formes spécifiques. Elle fonctionne en appliquant un filtre (ou noyau), qui est une petite matrice (par exemple, de taille 3x3), sur l'image pixel par pixel.
def Conv2D(params, inputs, kernel_size, stride, padding):
kernels, biases = params
input_height, input_width, input_channels = inputs.shape
# Calcul des dimensions de sortie
output_height = (input_height - kernel_size + 2 * padding) // stride + 1
output_width = (input_width - kernel_size + 2 * padding) // stride + 1
# Ajout de padding
padded_input = jnp.pad(
inputs,
((padding, padding), (padding, padding), (0, 0)),
mode='constant'
)
# Extraction des fenêtres
y_indices = (jnp.arange(output_height)[:, None] * stride) + jnp.arange(kernel_size)
x_indices = (jnp.arange(output_width)[:, None] * stride) + jnp.arange(kernel_size)
windows = padded_input[y_indices[:, None, :, None],
x_indices[None, :, None, :],
:]
windows = windows.reshape(output_height, output_width, kernel_size, kernel_size, input_channels)
# Calcul de la convolution
def compute_channel(args):
kernel, bias = args
return jnp.sum(windows * kernel, axis=(2, 3, 4)) + bias
output = jax.vmap(compute_channel)(kernels, biases)
return jnp.transpose(output, (1, 2, 0))
Le pooling est une étape cruciale dans un CNN, car elle permet de réduire la taille des données tout en conservant les informations essentielles extraites par les couches de convolution. Cette opération diminue non seulement la complexité du modèle, mais aussi la sensibilité aux petites variations dans les images, comme des décalages mineurs.
def avg_pool(input_data, window_shape=(2, 2), strides=(2, 2)):
input_height, input_width, num_channels = input_data.shape
window_height, window_width = window_shape
stride_height, stride_width = strides
output_height = (input_height - window_height) // stride_height + 1
output_width = (input_width - window_width) // stride_width + 1
output = jnp.zeros((output_height, output_width, num_channels))
for y in range(output_height):
for x in range(output_width):
y_start = y * stride_height
y_end = y_start + window_height
x_start = x * stride_width
x_end = x_start + window_width
window = input_data[y_start:y_end, x_start:x_end, :]
output = output.at[y, x, :].set(jnp.mean(window, axis=(0, 1)))
return output
Après avoir extrait des caractéristiques grâce aux couches convolutives et au pooling, les couches denses jouent un rôle crucial dans la transformation de ces caractéristiques en prédictions finales. Ces couches connectent chaque neurone de la couche précédente à tous les neurones de la couche suivante.
def dense_layer(params, inputs):
weights, biases = params
return jnp.dot(inputs, weights) + biases
Nous allons créer un CNN simple composé des éléments suivants :
Il est important de noter que JAX est un cadre entièrement fonctionnel. Cela signifie que les paramètres du modèle sont traités comme un ensemble distinct de nombres, existant « en dehors » du modèle lui-même. Cette approche permet une grande flexibilité dans la conception et l'optimisation des modèles.
import jax
import jax.numpy as jnp
from jax import random
def initialize_Conv2D(key, input_channels, kernel_size, num_kernels):
w_key, b_key = random.split(key)
weight_shape = (kernel_size, kernel_size, input_channels, num_kernels)
bias_shape = (num_kernels,)
weights = random.normal(w_key, weight_shape) * 0.01
biases = random.normal(b_key, bias_shape) * 0.01
return weights, biases
def initialize_dense_layer(key, input_dim, output_dim):
w_key, b_key = random.split(key)
weights = random.normal(w_key, (input_dim, output_dim)) * 0.01
biases = random.normal(b_key, (output_dim,)) * 0.01
return weights, biases
# Initialisation des paramètres du CNN
cnn_parameters = {
'conv1': None,
'conv2': None,
'dense1': None,
'dense2': None,
}
key = random.PRNGKey(42) # Clé aléatoire initiale
key, *subkeys = random.split(key, 5)
# Initialisation des couches Conv2D
cnn_parameters['conv1'] = initialize_Conv2D(
subkeys[0], input_channels=1, kernel_size=3, num_kernels=32
)
cnn_parameters['conv2'] = initialize_Conv2D(
subkeys[1], input_channels=32, kernel_size=3, num_kernels=64
)
# Initialisation des couches Dense
cnn_parameters['dense1'] = initialize_dense_layer(
subkeys[2], input_dim=3136, output_dim=256
)
cnn_parameters['dense2'] = initialize_dense_layer(
subkeys[3], input_dim=256, output_dim=10
)
Voici comment assembler les différentes parties du modèle :
def cnn(inputs, cnn_parameters):
x = jnp.transpose(inputs, (1, 2, 0)) # Conversion au format (H, W, C)
# Première couche Conv + Pool
x = Conv2D(cnn_parameters['conv1'], x, kernel_size=3, stride=1, padding=1)
x = avg_pool(x, window_shape=(2, 2), strides=(2, 2))
# Seconde couche Conv + Pool
x = Conv2D(cnn_parameters['conv2'], x, kernel_size=3, stride=1, padding=1)
x = avg_pool(x, window_shape=(2, 2), strides=(2, 2))
# Aplatir
x = jnp.reshape(x, (-1,))
# Couches Dense
x = dense_layer(cnn_parameters['dense1'], x)
x = jax.nn.relu(x) # Activation ReLU
x = dense_layer(cnn_parameters['dense2'], x)
return x
Pour entraîner un CNN, il est essentiel de mesurer l'écart entre les prédictions du modèle et les vraies valeurs. Cela se fait via une fonction de perte. Dans notre cas, nous utilisons la fonction de perte d'entropie croisée (cross-entropy loss), qui est couramment utilisée pour les problèmes de classification.
def cross_entropy_loss(params, inputs, targets):
preds = cnn(inputs, params)
one_hot_targets = jax.nn.one_hot(targets, 10)
loss = -jnp.sum(one_hot_targets * jax.nn.log_softmax(preds))
return loss
La rétropropagation est une technique fondamentale qui permet de calculer les gradients de la fonction de perte par rapport à chaque paramètre du modèle. Ces gradients indiquent dans quelle direction modifier les poids pour réduire l'erreur. En d'autres termes, ils guident le modèle vers une meilleure performance.
Grâce à JAX, nous pouvons calculer facilement les gradients en utilisant la fonction `jax.value_and_grad`.
loss, grads = jax.value_and_grad(cross_entropy_loss)(params, inputs, targets)
L'optimisation consiste à mettre à jour les paramètres du modèle en fonction des gradients calculés lors de la rétropropagation. La méthode la plus courante est la descente de gradient, qui ajuste chaque paramètre en fonction de son gradient et d'un taux d'apprentissage (`learning_rate`). Voici comment cette mise à jour est effectuée dans notre code :
learning_rate = 0.001
@jax.jit
def update_params(params, inputs, targets):
loss, grads = jax.value_and_grad(cross_entropy_loss)(params, inputs, targets)
updated_params = jax.tree_map(lambda p, g: p - learning_rate * g, params, grads)
return updated_params, loss
Le taux d'apprentissage (`learning_rate`) contrôle la vitesse à laquelle les paramètres sont mis à jour. Un taux trop élevé peut rendre l'entraînement instable, tandis qu'un taux trop faible peut ralentir la convergence.
La fonction de perte, la rétropropagation et l'optimisation sont des étapes cruciales dans l'entraînement d'un modèle CNN :
Ces trois étapes travaillent ensemble pour entraîner le modèle et l'amener à généraliser correctement sur de nouvelles données.
L'étape d'entraînement est cruciale pour ajuster les paramètres du modèle afin de minimiser la fonction de perte et d'améliorer ses performances. Voici comment ce processus est mis en œuvre :
@jax.jit
def train_step(patches, cnn_parameters, target):
# Compute gradients
current_loss, grads = jax.value_and_grad(cross_entropy_loss, argnums=1)(
patches,
cnn_parameters,
target)
# Update parameters
updated_params = jax.tree_map(lambda p, g: p - 0.01 * g, cnn_parameters, grads)
return current_loss, updated_params
num_epochs = 20
for epoch in range(num_epochs):
progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")
for i, (data, target) in progress_bar:
# Convert to numpy
data = jnp.asarray(data.numpy(), dtype=jnp.float32)
target = jnp.asarray(target.numpy(), dtype=jnp.float32)
# Reshape and get one hot for loss
target_one_hot = jax.nn.one_hot(target, num_classes)
current_loss, cnn_parameters = train_step(data, cnn_parameters, target_one_hot)
progress_bar.set_postfix({'loss': current_loss})
eval_acc = eval(cnn_parameters)
print(f'Epoch: {epoch}, Eval acc: {eval_acc}')
Le modèle présenté dans cet article a été entraîné sur un TPU fourni par Google Colab, permettant de terminer l'entraînement en seulement 20 minutes. Cependant, les performances obtenues restent modestes, avec une précision d'environ 3%, ce qui est extrêmement faible pour une tâche de classification comme MNIST. Cette limitation s'explique par le caractère éducatif du modèle, qui n'a pas été optimisé pour des performances industrielles.
Bien que le modèle CNN présenté ici soit conçu à des fins éducatives, il existe plusieurs façons d'améliorer ses performances et de le rendre plus adapté à des applications réelles :
Vous avez maintenant tous les outils nécessaires pour créer et entraîner un CNN en utilisant uniquement JAX. Ce guide vous a permis de comprendre les étapes clés telles que la convolution, le pooling, et l'optimisation des paramètres.
Pour aller plus loin, n'hésitez pas à expérimenter avec des architectures plus complexes ou à tester votre modèle sur des datasets plus variés comme CIFAR-10 ou ImageNet. Bonne continuation dans votre parcours en deep learning !
Je tiens à remercier chaleureusement Alessio Devoto, dont le travail sur l'implémentation de modèles en pure JAX a été une source d'inspiration précieuse pour cet article. De plus, le code utilisé dans cet article est disponible dans ce dépôt GitHub : mnist_jax.ipynb. N'hésitez pas à consulter ces ressources pour approfondir vos connaissances et découvrir des implémentations supplémentaires. Merci également à vous, lecteur(ice), pour votre curiosité et votre engagement à apprendre. Bonne continuation dans votre parcours en deep learning !
Pour approfondir vos connaissances et explorer des outils avancés, voici quelques ressources incontournables :
Ma recommandation musicale du jour : à écouter sans modération !
Écouter sur YouTube