Jonathan Suru

Introduction à JAX et Flax

Depuis mes débuts dans le domaine de l'apprentissage automatique, j'ai exploré de nombreux frameworks pour développer des modèles de deep learning. Ces derniers temps, j'ai choisi de me concentrer sur deux outils en particulier : JAX et Flax. Dans cet article, je vous expliquerai pourquoi cette combinaison est devenue ma préférée, et comment elle se distingue des autres frameworks comme TensorFlow ou PyTorch.

Pourquoi JAX et FLAX ?

Les frameworks traditionnels comme PyTorch ou TensorFlow dominent le paysage du Deep Learning, mais JAX et Flax se démarquent par leur vitesse, flexibilité et contrôle fin des calculs. Conçus pour exploiter pleinement les GPU/TPU, ils sont idéaux pour la recherche et les applications gourmandes en ressources. Cet article détaille leurs fonctionnalités clés, leurs différences avec NumPy, et comment contourner leurs limites.

Qu'est-ce que JAX ?

JAX est un paquet python pour écrire des transformations numériques composables. Il tire parti d'Autograd et de XLA (Accelerated Linear Algebra) pour réaliser des calculs numériques à haute performance, ce qui est particulièrement pertinent dans le machine learning.Il peut être considéré comme une version accélérée de NumPy et est couramment utilisé avec des bibliothèques de réseaux de neurones de plus haut niveau telles que Flax.

JAX vs NumPy : Similitudes et Différences

Similitudes

JAX et NumPy permettent tous deux d'effectuer des opérations basées sur des tableaux, en utilisant une interface similaire. Cela signifie que vous pouvez exécuter bon nombre de vos opérations NumPy préférées dans JAX, en utilisant une API similaire.

Exemple:


import jax.numpy as jnp

# Addition élément par élément
jax_array1 + jax_array2

# Multiplication matricielle
jnp.dot(matrix1, matrix2)

# Fonctions mathématiques
jnp.sin(x)
    

Différences

Bien que JAX et NumPy présentent certaines similitudes, ils ont également des différences importantes. Nous aborderons ici quelques différences d'API.

1. Les tableaux Jax sont immuables.

Les tableaux JAX sont immuables, ce qui signifie que vous ne pouvez pas modifier les valeurs d'un tableau une fois qu'il a été créé. Cela diffère de NumPy, où vous pouvez modifier les valeurs d'un tableau en place.


# NumPy (modifiable)
np_array[0] = 10

# JAX (immutable)
new_array = jax_array.at[0].set(10)
        

2. Génération Aléatoire

JAX exige que vous soyez plus explicite lors de la génération de nombres aléatoires. Vous devez passer une clé (key) à chaque fois que vous appelez une fonction qui a un certain degré d'aléatoire.

Dans NumPy :


np.random.seed(0)
valeurs = np.random.rand(3)
print(valeurs)  # [0.5488135  0.71518937 0.60276338]
            

Dans JAX (avec gestion explicite des clés) :


from jax import random

key = random.PRNGKey(0)
valeurs = random.uniform(key, (3,))
print(valeurs)  # [0.5488135  0.71518937 0.60276338]

# Pour générer plusieurs ensembles
key, subkey = random.split(key)
nouvelles_valeurs = random.uniform(subkey, (3,))
            

Primitives Essentielles de JAX

jit : Compilation Juste-à-Temps

jit (Compilation Juste-à-Temps) - compile et met en cache les fonctions Python JAX afin qu'elles puissent être exécutées efficacement sur XLA pour accélérer les appels de fonctions. jit prend en entrée une fonction Python et renvoie une version compilée de cette fonction.


from jax import jit

@jit
def add(a, b):
    return a + b

result = add(jnp.array([1,2]), jnp.array([3,4]))
    

Il existe des règles concernant les types de fonctions pouvant être compilées avec jit. Par exemple, les fonctions qui contiennent un flux de contrôle Python (comme des boucles ou des conditions) peuvent ne pas être compatibles avec jit.

grad : Différentiation Automatique

grad est utilisé pour calculer automatiquement le gradient d'une fonction dans JAX. Il peut être appliqué aux fonctions Python et NumPy, ce qui signifie que vous pouvez différencier à travers des boucles, des branches, des récursions et des fermetures.


from jax import grad

def square(x):
    return x**2

gradient = grad(square)(3.0)  # Renvoie 6.0
    

vmap : Vectorisation Automatique

vmap (Vectorizing map) vectorise automatiquement vos fonctions python. Cela signifie que vous pouvez écrire une fonction qui opère sur des exemples individuels, et vmap l'appliquera automatiquement à un lot d'exemples.


from jax import vmap

batched_square = vmap(square)
result = batched_square(jnp.array([1,2,3]))  # [1, 4, 9]
    

Sans vmap, vous devriez écrire une boucle pour appliquer la fonction à chaque élément du lot. vmap vectorise automatiquement la fonction, la rendant plus efficace et concise.

JAX est différent des frameworks tels que PyTorch ou Tensorflow (TF). Il est plus bas niveau et minimaliste. JAX offre simplement un ensemble de primitives (opérations simples) comme jit et vmap, et s'appuie sur d'autres bibliothèques pour d'autres choses, par exemple en utilisant le chargeur de données de PyTorch ou TF. En raison de la simplicité de JAX, il est couramment utilisé avec des bibliothèques de réseaux de neurones de plus haut niveau telles que Haiku ou Flax.

Flax : Une API Flexible pour les Réseaux Neuraux

Flax simplifie la création et l'entraînement de modèles en combinant la puissance de JAX avec une structure orientée objet. Ses composants principaux sont :

Linen API : Définition des Couches

Les couches neurales sont définies via des modules immuables (dataclasses). Exemple :

import flax.linen as nn

class SimpleClassifier(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x

Gestion des Paramètres avec Pytrees

Les paramètres (poids, biais) sont stockés dans des Pytrees, des structures arborescentes imbriquées. JAX exige une initialisation explicite des modèles :

model = SimpleClassifier()
# Initialisation avec une clé aléatoire JAX
params = model.init(jax.random.PRNGKey(0), x_sample)
# Inférence
predictions = model.apply(params, x_new)

La fonction init génère un Pytree contenant tous les paramètres, tandis que apply exécute le modèle avec ces paramètres.

Optimisation avec Optax

Optax est une bibliothèque dédiée à l'optimisation dans JAX. Elle permet de :

Optax gère les gradients et les mises à jour de manière fonctionnelle, sans effets de bord.

TrainState : Centraliser l'Entraînement

Pour simplifier l'entraînement, Flax fournit TrainState, un conteneur immutable regroupant :

from flax.training import train_state

state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optimizer
)

Boucle d'Entraînement

Exemple de fonction d'entraînement compilée avec jax.jit :

@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        preds = state.apply_fn(params, batch['inputs'])
        loss = jnp.mean((preds - batch['labels']) ** 2)
        return loss

    # Calcul des gradients et mise à jour
    grad_fn = jax.grad(loss_fn)
    grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state

# Boucle d'entraînement
for epoch in range(num_epochs):
    state = train_step(state, next_batch)

Conclusion : Un duo exigeant mais gratifiant

Travailler avec JAX et Flax n’est pas toujours simple – les ressources sont encore rares, et la communauté, bien que dynamique, reste petite. Mais c’est justement ce qui rend l’aventure passionnante. En recodant des architectures PyTorch ou TensorFlow "à la main" avec ces outils, j’ai dû comprendre ce qui se cache sous la couche d’abstraction des frameworks traditionnels : gestion des clés aléatoires, vectorisation explicite, gradients manuels… Ces contraintes m’ont forcée à maîtriser des concepts souvent occultés par la simplicité de torch.nn ou tf.keras. Et finalement, cette rigueur paye : les principes de différentiation automatique, de compilation XLA, ou de gestion fonctionnelle des états sont des compétences qui dépassent largement JAX. Aujourd’hui, même quand je reviens à PyTorch, je code mieux – en comprenant *pourquoi* les choses fonctionnent, pas seulement *comment*. C’est ça, la magie de ce duo : il vous rend meilleur, même quand vous ne l’utilisez plus.

Liens Utiles

Pour approfondir vos connaissances et explorer des outils avancés, voici quelques ressources incontournables :

Ma recommandation musicale du jour : à écouter sans modération !

Écouter sur YouTube