Jonathan Suru

Entraîner un CNN avec JAX : Un Guide Simple et Complet

Bienvenue dans ce guide pratique dédié à l'apprentissage des réseaux de neurones convolutifs (CNN) avec JAX ! Dans cet article, nous allons construire un modèle simple pour classer des images manuscrites issues du dataset MNIST. L'objectif principal est d'apprendre et de comprendre les bases des CNN ainsi que les fonctionnalités essentielles de JAX.

JAX est une bibliothèque Python idéale pour explorer le deep learning grâce à ses outils puissants comme le calcul automatique des gradients (`jax.grad`), la compilation rapide avec `@jax.jit`, et une syntaxe très proche de NumPy. Ces caractéristiques en font un excellent choix pour expérimenter et maîtriser les concepts fondamentaux.

Qu'est-ce qu'un CNN ? Une Explication Simple

Un CNN (réseau de neurones convolutifs) est un type de modèle spécialement conçu pour traiter des données structurées en grille, comme les images. Il est particulièrement efficace pour analyser des pixels et extraire des informations utiles à partir d'eux.

Pourquoi les CNN sont-ils si puissants pour ces tâches ? Explorons cela étape par étape :

Applications des CNN

Les CNN sont utilisés dans une grande variété d'applications modernes, notamment :

Les Composants Clés d'un CNN

Pour bien comprendre comment fonctionne un CNN, il est important de connaître ses composants principaux :

1. Convolution 2D (Conv2D)

La convolution est l'opération fondamentale d'un CNN, car elle permet d'extraire des caractéristiques visuelles importantes des images, comme des bords, des textures ou des formes spécifiques. Elle fonctionne en appliquant un filtre (ou noyau), qui est une petite matrice (par exemple, de taille 3x3), sur l'image pixel par pixel.


def Conv2D(params, inputs, kernel_size, stride, padding):
    kernels, biases = params
    input_height, input_width, input_channels = inputs.shape

    # Calcul des dimensions de sortie
    output_height = (input_height - kernel_size + 2 * padding) // stride + 1
    output_width = (input_width - kernel_size + 2 * padding) // stride + 1

    # Ajout de padding
    padded_input = jnp.pad(
            inputs,
            ((padding, padding), (padding, padding), (0, 0)),
            mode='constant'
    )

    # Extraction des fenêtres
    y_indices = (jnp.arange(output_height)[:, None] * stride) + jnp.arange(kernel_size)
    x_indices = (jnp.arange(output_width)[:, None] * stride) + jnp.arange(kernel_size)

    windows = padded_input[y_indices[:, None, :, None],
                           x_indices[None, :, None, :],
                           :]
    windows = windows.reshape(output_height, output_width, kernel_size, kernel_size, input_channels)

    # Calcul de la convolution
    def compute_channel(args):
        kernel, bias = args
        return jnp.sum(windows * kernel, axis=(2, 3, 4)) + bias

    output = jax.vmap(compute_channel)(kernels, biases)
    return jnp.transpose(output, (1, 2, 0))
    

2. Pooling

Le pooling est une étape cruciale dans un CNN, car elle permet de réduire la taille des données tout en conservant les informations essentielles extraites par les couches de convolution. Cette opération diminue non seulement la complexité du modèle, mais aussi la sensibilité aux petites variations dans les images, comme des décalages mineurs.


def avg_pool(input_data, window_shape=(2, 2), strides=(2, 2)):
    input_height, input_width, num_channels = input_data.shape
    window_height, window_width = window_shape
    stride_height, stride_width = strides

    output_height = (input_height - window_height) // stride_height + 1
    output_width = (input_width - window_width) // stride_width + 1

    output = jnp.zeros((output_height, output_width, num_channels))

    for y in range(output_height):
        for x in range(output_width):
            y_start = y * stride_height
            y_end = y_start + window_height
            x_start = x * stride_width
            x_end = x_start + window_width

            window = input_data[y_start:y_end, x_start:x_end, :]
            output = output.at[y, x, :].set(jnp.mean(window, axis=(0, 1)))

    return output
    

3. Couches Denses (Fully Connected Layers)

Après avoir extrait des caractéristiques grâce aux couches convolutives et au pooling, les couches denses jouent un rôle crucial dans la transformation de ces caractéristiques en prédictions finales. Ces couches connectent chaque neurone de la couche précédente à tous les neurones de la couche suivante.


def dense_layer(params, inputs):
    weights, biases = params
    return jnp.dot(inputs, weights) + biases
    

Architecture du Modèle

Nous allons créer un CNN simple composé des éléments suivants :

Il est important de noter que JAX est un cadre entièrement fonctionnel. Cela signifie que les paramètres du modèle sont traités comme un ensemble distinct de nombres, existant « en dehors » du modèle lui-même. Cette approche permet une grande flexibilité dans la conception et l'optimisation des modèles.


import jax
import jax.numpy as jnp
from jax import random

def initialize_Conv2D(key, input_channels, kernel_size, num_kernels):
    w_key, b_key = random.split(key)
    weight_shape = (kernel_size, kernel_size, input_channels, num_kernels)
    bias_shape = (num_kernels,)
    weights = random.normal(w_key, weight_shape) * 0.01
    biases = random.normal(b_key, bias_shape) * 0.01
    return weights, biases

def initialize_dense_layer(key, input_dim, output_dim):
    w_key, b_key = random.split(key)
    weights = random.normal(w_key, (input_dim, output_dim)) * 0.01
    biases = random.normal(b_key, (output_dim,)) * 0.01
    return weights, biases

# Initialisation des paramètres du CNN
cnn_parameters = {
    'conv1': None,
    'conv2': None,
    'dense1': None,
    'dense2': None,
}

key = random.PRNGKey(42)  # Clé aléatoire initiale
key, *subkeys = random.split(key, 5)

# Initialisation des couches Conv2D
cnn_parameters['conv1'] = initialize_Conv2D(
    subkeys[0], input_channels=1, kernel_size=3, num_kernels=32
)
cnn_parameters['conv2'] = initialize_Conv2D(
    subkeys[1], input_channels=32, kernel_size=3, num_kernels=64
)

# Initialisation des couches Dense
cnn_parameters['dense1'] = initialize_dense_layer(
    subkeys[2], input_dim=3136, output_dim=256
)
cnn_parameters['dense2'] = initialize_dense_layer(
    subkeys[3], input_dim=256, output_dim=10
)
    

Implémentation du Modèle

Voici comment assembler les différentes parties du modèle :


def cnn(inputs, cnn_parameters):
    x = jnp.transpose(inputs, (1, 2, 0))  # Conversion au format (H, W, C)

    # Première couche Conv + Pool
    x = Conv2D(cnn_parameters['conv1'], x, kernel_size=3, stride=1, padding=1)
    x = avg_pool(x, window_shape=(2, 2), strides=(2, 2))

    # Seconde couche Conv + Pool
    x = Conv2D(cnn_parameters['conv2'], x, kernel_size=3, stride=1, padding=1)
    x = avg_pool(x, window_shape=(2, 2), strides=(2, 2))

    # Aplatir
    x = jnp.reshape(x, (-1,))

    # Couches Dense
    x = dense_layer(cnn_parameters['dense1'], x)
    x = jax.nn.relu(x)  # Activation ReLU
    x = dense_layer(cnn_parameters['dense2'], x)

    return x
    

Fonction de Perte

Pour entraîner un CNN, il est essentiel de mesurer l'écart entre les prédictions du modèle et les vraies valeurs. Cela se fait via une fonction de perte. Dans notre cas, nous utilisons la fonction de perte d'entropie croisée (cross-entropy loss), qui est couramment utilisée pour les problèmes de classification.


def cross_entropy_loss(params, inputs, targets):
    preds = cnn(inputs, params)
    one_hot_targets = jax.nn.one_hot(targets, 10)
    loss = -jnp.sum(one_hot_targets * jax.nn.log_softmax(preds))
    return loss
    

Rétropropagation

La rétropropagation est une technique fondamentale qui permet de calculer les gradients de la fonction de perte par rapport à chaque paramètre du modèle. Ces gradients indiquent dans quelle direction modifier les poids pour réduire l'erreur. En d'autres termes, ils guident le modèle vers une meilleure performance.

Grâce à JAX, nous pouvons calculer facilement les gradients en utilisant la fonction `jax.value_and_grad`.


loss, grads = jax.value_and_grad(cross_entropy_loss)(params, inputs, targets)
    

Optimisation

L'optimisation consiste à mettre à jour les paramètres du modèle en fonction des gradients calculés lors de la rétropropagation. La méthode la plus courante est la descente de gradient, qui ajuste chaque paramètre en fonction de son gradient et d'un taux d'apprentissage (`learning_rate`). Voici comment cette mise à jour est effectuée dans notre code :


learning_rate = 0.001

@jax.jit
def update_params(params, inputs, targets):
    loss, grads = jax.value_and_grad(cross_entropy_loss)(params, inputs, targets)
    updated_params = jax.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return updated_params, loss
    

Le taux d'apprentissage (`learning_rate`) contrôle la vitesse à laquelle les paramètres sont mis à jour. Un taux trop élevé peut rendre l'entraînement instable, tandis qu'un taux trop faible peut ralentir la convergence.

Importance de la Fonction de Perte, de la Rétropropagation et de l'Optimisation

La fonction de perte, la rétropropagation et l'optimisation sont des étapes cruciales dans l'entraînement d'un modèle CNN :

Ces trois étapes travaillent ensemble pour entraîner le modèle et l'amener à généraliser correctement sur de nouvelles données.

Entraînement du Modèle

L'étape d'entraînement est cruciale pour ajuster les paramètres du modèle afin de minimiser la fonction de perte et d'améliorer ses performances. Voici comment ce processus est mis en œuvre :


@jax.jit
def train_step(patches, cnn_parameters, target):
    # Compute gradients
    current_loss, grads = jax.value_and_grad(cross_entropy_loss, argnums=1)(
            patches,
            cnn_parameters,
            target)
    # Update parameters
    updated_params = jax.tree_map(lambda p, g: p - 0.01 * g, cnn_parameters, grads)
    return current_loss, updated_params

num_epochs = 20

for epoch in range(num_epochs):
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")
    for i, (data, target) in progress_bar:
        # Convert to numpy
        data = jnp.asarray(data.numpy(), dtype=jnp.float32)
        target = jnp.asarray(target.numpy(), dtype=jnp.float32)
        # Reshape and get one hot for loss
        target_one_hot = jax.nn.one_hot(target, num_classes)
        current_loss, cnn_parameters = train_step(data, cnn_parameters, target_one_hot)
        progress_bar.set_postfix({'loss': current_loss})

    eval_acc = eval(cnn_parameters)
    print(f'Epoch: {epoch}, Eval acc: {eval_acc}')
    

Le modèle présenté dans cet article a été entraîné sur un TPU fourni par Google Colab, permettant de terminer l'entraînement en seulement 20 minutes. Cependant, les performances obtenues restent modestes, avec une précision d'environ 3%, ce qui est extrêmement faible pour une tâche de classification comme MNIST. Cette limitation s'explique par le caractère éducatif du modèle, qui n'a pas été optimisé pour des performances industrielles.

Amélioration du Modèle

Bien que le modèle CNN présenté ici soit conçu à des fins éducatives, il existe plusieurs façons d'améliorer ses performances et de le rendre plus adapté à des applications réelles :

Conclusion

Vous avez maintenant tous les outils nécessaires pour créer et entraîner un CNN en utilisant uniquement JAX. Ce guide vous a permis de comprendre les étapes clés telles que la convolution, le pooling, et l'optimisation des paramètres.

Pour aller plus loin, n'hésitez pas à expérimenter avec des architectures plus complexes ou à tester votre modèle sur des datasets plus variés comme CIFAR-10 ou ImageNet. Bonne continuation dans votre parcours en deep learning !

Je tiens à remercier chaleureusement Alessio Devoto, dont le travail sur l'implémentation de modèles en pure JAX a été une source d'inspiration précieuse pour cet article. De plus, le code utilisé dans cet article est disponible dans ce dépôt GitHub : mnist_jax.ipynb. N'hésitez pas à consulter ces ressources pour approfondir vos connaissances et découvrir des implémentations supplémentaires. Merci également à vous, lecteur(ice), pour votre curiosité et votre engagement à apprendre. Bonne continuation dans votre parcours en deep learning !

Liens Utiles

Pour approfondir vos connaissances et explorer des outils avancés, voici quelques ressources incontournables :

Ma recommandation musicale du jour : à écouter sans modération !

Écouter sur YouTube