Dans cet article, je détaille l’implémentation d’un Vision Transformer (ViT) simplifié pour la classification des chiffres manuscrits MNIST, en m’appuyant sur les bibliothèques Flax
(framework de deep learning basé sur JAX
) et Optax
pour l’optimisation.
Les Vision Transformers (ViT) représentent une avancée majeure dans le domaine de la vision par ordinateur, en adaptant l’architecture Transformer, initialement conçue pour le traitement du langage naturel (NLP), aux tâches d’analyse d’images. Contrairement aux réseaux convolutifs (CNN) qui exploitent des filtres locaux pour capturer des motifs spatiaux, les ViT décomposent une image en patchs (morceaux) linéarisés, traités comme une séquence de tokens similaires à des mots. Ces patchs sont ensuite projetés dans un espace d’embeddings, enrichis par des informations de position (embeddings positionnels), et passés à travers un encodeur Transformer. Ce dernier, grâce à son mécanisme d’attention multi-têtes , modélise efficacement les interactions à longue portée entre les régions de l’image, permettant une compréhension globale et contextuelle.
Cette approche novatrice a été formalisée dans l’article fondateur "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (Dosovitskiy et al., 2020). Les auteurs y démontrent que les Transformers, appliqués à des patchs d’images (par exemple, 16×16 pixels pour ImageNet), surpassent les CNN classiques en précision lorsque le modèle est pré-entraîné sur des jeux de données massifs (comme JFT-300M). Leur travail souligne également l’importance cruciale de l’échelle : les ViT atteignent des performances compétitives uniquement lorsqu’ils sont entraînés sur des corpus suffisamment grands, mettant en lumière les compromis entre complexité calculatoire et généralisation. Ce papier a non seulement challengé l’hégémonie des CNN en vision par ordinateur, mais a aussi ouvert la voie à une nouvelle génération de modèles hybrides (combinant CNN et Transformers) et à des applications allant de la classification à la segmentation d’images.
Le bloc TransformerEncoderWithLayerNorm
est l’unité de base de l’encodeur Transformer. Il combine des mécanismes d’attention multi-têtes et un réseau feed-forward (MLP), le tout régularisé par des normalisations par couche (LayerNorm) et des résidus.
nnx.MultiHeadAttention
divise les embeddings en num_heads
têtes, permettant au modèle de focaliser sur différentes parties de l’image simultanément.dropout_rate=0.2
: Désactive aléatoirement 20% des connections pendant l’entraînement pour éviter le surajustement.deterministic=False
: Active le dropout uniquement en mode entraînement.hidden_size
→ mlp_dim
→ hidden_size
) avec :Exemple de flux :
x → LayerNorm → Attention → + → LayerNorm → MLP → + → Sortie ↓___________________________↑ ↓____________________↑
La classe VisionTransformerWithLayerNorm
assemble les composants pour traiter une image et produire une prédiction.
patch_embeddings
) découpe l’image en patchs de taille patch_size
et les projette dans un espace de dimension embed_dim
.patch_size=7
génère 16 patchs (4x4), chacun converti en un vecteur de 256 dimensions.[CLS]
(vecteur appris) est ajouté au début de la séquence de patchs. Son embedding final servira à la classification.num_layers
blocs TransformerEncoderWithLayerNorm
transforme les embeddings en représentations contextuelles.[CLS]
est isolé (x = x[:, 0]
) et passé dans une couche linéaire (classifier
) pour prédire les 10 classes MNIST.Mécanisme d’inférence :
Image → Patchs → [CLS] + Patchs → Embeddings → Encodeur → [CLS] → Classification
Exemple concret :
[CLS]
+ 16 patchs → 17 embeddings.[CLS]
final est transformé en logits pour les 10 classes.truncated_normal
(stddev=0.02), une heuristique éprouvée dans les Transformers.patch_size=7
, évitant un nombre excessif de tokens.Cette architecture illustre comment les Transformers, initialement conçus pour le texte, peuvent être adaptés aux images en exploitant leur structure séquentielle et leur capacité à modéliser des dépendances à longue portée.
Après 20 epochs d’entraînement, le modèle atteint une précision de 88,13 % sur les données de test, avec une amélioration continue malgré des fluctuations. La perte d’entraînement diminue progressivement (de 1,9 à 0,35), tandis que la précision passe de 34,8 % à près de 88 %, reflétant une bonne généralisation. Bien que des oscillations apparaissent (ex. : perte remontant temporairement à 0,49 à l’epoch 19), le modèle finit par converger, démontrant l’efficacité des Transformers pour MNIST. Ces résultats pourraient encore être optimisés en ajustant le taux d’apprentissage
ou le dropout
pour stabiliser l’entraînement.
0.1
au lieu de 0.2
) pour limiter les fluctuations de perte observées en fin d’entraînement.num_layers
(ex. : 8 au lieu de 6) ou embed_dim
pour renforcer la capacité du réseau.Ce projet démontre qu’un Vision Transformer (ViT) léger, entraîné sur MNIST, atteint 88 % de précision en 20 epochs, confirmant l’adaptabilité des Transformers à des tâches simples de vision par ordinateur. Bien que les résultats soient encourageants, des ajustements (taux d’apprentissage, dropout
) pourraient réduire les oscillations et améliorer la stabilité. Cette implémentation sert de base pour explorer des architectures plus complexes (ex. : Swin Transformer
) ou des datasets plus exigeants (CIFAR-10, ImageNet). Les ViT, avec leur capacité à capturer des relations globales, ouvrent des perspectives passionnantes pour l’avenir de l’apprentissage visuel.
Pour approfondir vos connaissances et explorer des outils avancés, voici quelques ressources incontournables :
Ma recommandation musicale du jour : à écouter sans modération !
Écouter sur YouTube