Jonathan Suru

Implémentation d'une Vision Transformer (ViT) avec Flax et JAX pour la classification MNIST

Dans cet article, je détaille l’implémentation d’un Vision Transformer (ViT) simplifié pour la classification des chiffres manuscrits MNIST, en m’appuyant sur les bibliothèques Flax (framework de deep learning basé sur JAX) et Optax pour l’optimisation.

Les Vision Transformers (ViT) représentent une avancée majeure dans le domaine de la vision par ordinateur, en adaptant l’architecture Transformer, initialement conçue pour le traitement du langage naturel (NLP), aux tâches d’analyse d’images. Contrairement aux réseaux convolutifs (CNN) qui exploitent des filtres locaux pour capturer des motifs spatiaux, les ViT décomposent une image en patchs (morceaux) linéarisés, traités comme une séquence de tokens similaires à des mots. Ces patchs sont ensuite projetés dans un espace d’embeddings, enrichis par des informations de position (embeddings positionnels), et passés à travers un encodeur Transformer. Ce dernier, grâce à son mécanisme d’attention multi-têtes , modélise efficacement les interactions à longue portée entre les régions de l’image, permettant une compréhension globale et contextuelle.

Cette approche novatrice a été formalisée dans l’article fondateur "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (Dosovitskiy et al., 2020). Les auteurs y démontrent que les Transformers, appliqués à des patchs d’images (par exemple, 16×16 pixels pour ImageNet), surpassent les CNN classiques en précision lorsque le modèle est pré-entraîné sur des jeux de données massifs (comme JFT-300M). Leur travail souligne également l’importance cruciale de l’échelle : les ViT atteignent des performances compétitives uniquement lorsqu’ils sont entraînés sur des corpus suffisamment grands, mettant en lumière les compromis entre complexité calculatoire et généralisation. Ce papier a non seulement challengé l’hégémonie des CNN en vision par ordinateur, mais a aussi ouvert la voie à une nouvelle génération de modèles hybrides (combinant CNN et Transformers) et à des applications allant de la classification à la segmentation d’images.

Bloc d’encodage Transformer avec LayerNorm

Le bloc TransformerEncoderWithLayerNorm est l’unité de base de l’encodeur Transformer. Il combine des mécanismes d’attention multi-têtes et un réseau feed-forward (MLP), le tout régularisé par des normalisations par couche (LayerNorm) et des résidus.

Structure détaillée

  1. Normalisation par couche (LayerNorm) :
    Appliquée avant chaque sous-couche (attention et MLP), la normalisation stabilise les activations en centrant et en réduisant leurs valeurs. Contrairement à BatchNorm (qui normalise sur le batch), LayerNorm opère sur la dernière dimension des features, ce qui est mieux adapté aux séquences de patchs.
  2. Attention multi-têtes :
    La couche nnx.MultiHeadAttention divise les embeddings en num_heads têtes, permettant au modèle de focaliser sur différentes parties de l’image simultanément.
    Paramètres clés :
  3. Réseau feed-forward (MLP) :
    Une séquence de deux couches linéaires (hidden_sizemlp_dimhidden_size) avec :
  4. Résidus (skip connections) :
    Les sorties de l’attention et du MLP sont additionnées aux entrées (residual connections), facilitant la propagation des gradients et l’entraînement de réseaux profonds.

Exemple de flux :

x → LayerNorm → Attention → + → LayerNorm → MLP → + → Sortie
↓___________________________↑ ↓____________________↑
    

Modèle Vision Transformer (ViT) complet

La classe VisionTransformerWithLayerNorm assemble les composants pour traiter une image et produire une prédiction.

Étapes clés

  1. Découpage en patchs et embedding :
    Une couche convolutive (patch_embeddings) découpe l’image en patchs de taille patch_size et les projette dans un espace de dimension embed_dim.
    Exemple : Une image 28x28 avec patch_size=7 génère 16 patchs (4x4), chacun converti en un vecteur de 256 dimensions.
  2. Token CLS et embeddings positionnels :
    Un token spécial [CLS] (vecteur appris) est ajouté au début de la séquence de patchs. Son embedding final servira à la classification.
    Des embeddings positionnels (aussi appris) sont ajoutés pour encoder l’ordre spatial des patchs. Ces embeddings sont initialisés avec une distribution normale tronquée (meilleure stabilité).
  3. Encodeur Transformer :
    Une séquence de num_layers blocs TransformerEncoderWithLayerNorm transforme les embeddings en représentations contextuelles.
    Le dropout est appliqué après les embeddings pour réduire le surajustement.
  4. Classification :
    Après l’encodeur, le vecteur [CLS] est isolé (x = x[:, 0]) et passé dans une couche linéaire (classifier) pour prédire les 10 classes MNIST.

Mécanisme d’inférence :

Image → Patchs → [CLS] + Patchs → Embeddings → Encodeur → [CLS] → Classification
    

Exemple concret :

Points de conception critiques

Cette architecture illustre comment les Transformers, initialement conçus pour le texte, peuvent être adaptés aux images en exploitant leur structure séquentielle et leur capacité à modéliser des dépendances à longue portée.

Résultats de l’entraînement

Après 20 epochs d’entraînement, le modèle atteint une précision de 88,13 % sur les données de test, avec une amélioration continue malgré des fluctuations. La perte d’entraînement diminue progressivement (de 1,9 à 0,35), tandis que la précision passe de 34,8 % à près de 88 %, reflétant une bonne généralisation. Bien que des oscillations apparaissent (ex. : perte remontant temporairement à 0,49 à l’epoch 19), le modèle finit par converger, démontrant l’efficacité des Transformers pour MNIST. Ces résultats pourraient encore être optimisés en ajustant le taux d’apprentissage ou le dropout pour stabiliser l’entraînement.

Pistes d’amélioration

  1. Optimisation du taux d’apprentissage :
    Utiliser un schéma de décroissance (ex. : cosine annealing) pour réduire le taux d’apprentissage progressivement et stabiliser la convergence.
  2. Augmentation des données :
    Appliquer des transformations (rotation, inversion) pour diversifier les exemples d’entraînement et améliorer la généralisation.
  3. Ajustement du dropout :
    Réduire le taux de dropout (0.1 au lieu de 0.2) pour limiter les fluctuations de perte observées en fin d’entraînement.
  4. Modèle plus profond :
    Augmenter num_layers (ex. : 8 au lieu de 6) ou embed_dim pour renforcer la capacité du réseau.
  5. Validation croisée :
    Effectuer plusieurs splits des données pour évaluer la robustesse du modèle.

Conclusion

Ce projet démontre qu’un Vision Transformer (ViT) léger, entraîné sur MNIST, atteint 88 % de précision en 20 epochs, confirmant l’adaptabilité des Transformers à des tâches simples de vision par ordinateur. Bien que les résultats soient encourageants, des ajustements (taux d’apprentissage, dropout) pourraient réduire les oscillations et améliorer la stabilité. Cette implémentation sert de base pour explorer des architectures plus complexes (ex. : Swin Transformer) ou des datasets plus exigeants (CIFAR-10, ImageNet). Les ViT, avec leur capacité à capturer des relations globales, ouvrent des perspectives passionnantes pour l’avenir de l’apprentissage visuel.

Liens Utiles

Pour approfondir vos connaissances et explorer des outils avancés, voici quelques ressources incontournables :

Ma recommandation musicale du jour : à écouter sans modération !

Écouter sur YouTube