Depuis mes débuts dans le domaine de l'apprentissage automatique, j'ai exploré de nombreux frameworks pour développer des modèles de deep learning. Ces derniers temps, j'ai choisi de me concentrer sur deux outils en particulier : JAX et Flax. Dans cet article, je vous expliquerai pourquoi cette combinaison est devenue ma préférée, et comment elle se distingue des autres frameworks comme TensorFlow ou PyTorch.
Les frameworks traditionnels comme PyTorch ou TensorFlow dominent le paysage du Deep Learning, mais JAX et Flax se démarquent par leur vitesse, flexibilité et contrôle fin des calculs. Conçus pour exploiter pleinement les GPU/TPU, ils sont idéaux pour la recherche et les applications gourmandes en ressources. Cet article détaille leurs fonctionnalités clés, leurs différences avec NumPy, et comment contourner leurs limites.
JAX est un paquet python pour écrire des transformations numériques composables. Il tire parti d'Autograd et de XLA (Accelerated Linear Algebra) pour réaliser des calculs numériques à haute performance, ce qui est particulièrement pertinent dans le machine learning.Il peut être considéré comme une version accélérée de NumPy et est couramment utilisé avec des bibliothèques de réseaux de neurones de plus haut niveau telles que Flax.
JAX et NumPy permettent tous deux d'effectuer des opérations basées sur des tableaux, en utilisant une interface similaire. Cela signifie que vous pouvez exécuter bon nombre de vos opérations NumPy préférées dans JAX, en utilisant une API similaire.
Exemple:
import jax.numpy as jnp
# Addition élément par élément
jax_array1 + jax_array2
# Multiplication matricielle
jnp.dot(matrix1, matrix2)
# Fonctions mathématiques
jnp.sin(x)
Bien que JAX et NumPy présentent certaines similitudes, ils ont également des différences importantes. Nous aborderons ici quelques différences d'API.
Les tableaux JAX sont immuables, ce qui signifie que vous ne pouvez pas modifier les valeurs d'un tableau une fois qu'il a été créé. Cela diffère de NumPy, où vous pouvez modifier les valeurs d'un tableau en place.
# NumPy (modifiable)
np_array[0] = 10
# JAX (immutable)
new_array = jax_array.at[0].set(10)
JAX exige que vous soyez plus explicite lors de la génération de nombres aléatoires. Vous devez passer une clé (key) à chaque fois que vous appelez une fonction qui a un certain degré d'aléatoire.
Dans NumPy :
np.random.seed(0)
valeurs = np.random.rand(3)
print(valeurs) # [0.5488135 0.71518937 0.60276338]
Dans JAX (avec gestion explicite des clés) :
from jax import random
key = random.PRNGKey(0)
valeurs = random.uniform(key, (3,))
print(valeurs) # [0.5488135 0.71518937 0.60276338]
# Pour générer plusieurs ensembles
key, subkey = random.split(key)
nouvelles_valeurs = random.uniform(subkey, (3,))
jit (Compilation Juste-à-Temps) - compile et met en cache les fonctions Python JAX afin qu'elles puissent être exécutées efficacement sur XLA pour accélérer les appels de fonctions. jit prend en entrée une fonction Python et renvoie une version compilée de cette fonction.
from jax import jit
@jit
def add(a, b):
return a + b
result = add(jnp.array([1,2]), jnp.array([3,4]))
Il existe des règles concernant les types de fonctions pouvant être compilées avec jit. Par exemple, les fonctions qui contiennent un flux de contrôle Python (comme des boucles ou des conditions) peuvent ne pas être compatibles avec jit.
grad est utilisé pour calculer automatiquement le gradient d'une fonction dans JAX. Il peut être appliqué aux fonctions Python et NumPy, ce qui signifie que vous pouvez différencier à travers des boucles, des branches, des récursions et des fermetures.
from jax import grad
def square(x):
return x**2
gradient = grad(square)(3.0) # Renvoie 6.0
vmap (Vectorizing map) vectorise automatiquement vos fonctions python. Cela signifie que vous pouvez écrire une fonction qui opère sur des exemples individuels, et vmap l'appliquera automatiquement à un lot d'exemples.
from jax import vmap
batched_square = vmap(square)
result = batched_square(jnp.array([1,2,3])) # [1, 4, 9]
Sans vmap, vous devriez écrire une boucle pour appliquer la fonction à chaque élément du lot. vmap vectorise automatiquement la fonction, la rendant plus efficace et concise.
JAX est différent des frameworks tels que PyTorch ou Tensorflow (TF). Il est plus bas niveau et minimaliste. JAX offre simplement un ensemble de primitives (opérations simples) comme jit et vmap, et s'appuie sur d'autres bibliothèques pour d'autres choses, par exemple en utilisant le chargeur de données de PyTorch ou TF. En raison de la simplicité de JAX, il est couramment utilisé avec des bibliothèques de réseaux de neurones de plus haut niveau telles que Haiku ou Flax.
Flax simplifie la création et l'entraînement de modèles en combinant la puissance de JAX avec une structure orientée objet. Ses composants principaux sont :
Les couches neurales sont définies via des modules immuables (dataclasses). Exemple :
import flax.linen as nn
class SimpleClassifier(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(128)(x)
x = nn.relu(x)
x = nn.Dense(1)(x)
return x
Les paramètres (poids, biais) sont stockés dans des Pytrees, des structures arborescentes imbriquées. JAX exige une initialisation explicite des modèles :
model = SimpleClassifier()
# Initialisation avec une clé aléatoire JAX
params = model.init(jax.random.PRNGKey(0), x_sample)
# Inférence
predictions = model.apply(params, x_new)
La fonction init
génère un Pytree contenant tous les paramètres, tandis que apply
exécute le modèle avec ces paramètres.
Optax est une bibliothèque dédiée à l'optimisation dans JAX. Elle permet de :
SGD
, Adam
) :import optax
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)
optax.apply_updates
:def update_params(params, gradients, opt_state):
updates, new_opt_state = optimizer.update(gradients, opt_state)
new_params = optax.apply_updates(params, updates)
return new_params, new_opt_state
Optax gère les gradients et les mises à jour de manière fonctionnelle, sans effets de bord.
Pour simplifier l'entraînement, Flax fournit TrainState
, un conteneur immutable regroupant :
apply
pour l'inférencefrom flax.training import train_state
state = train_state.TrainState.create(
apply_fn=model.apply,
params=params,
tx=optimizer
)
Exemple de fonction d'entraînement compilée avec jax.jit
:
@jax.jit
def train_step(state, batch):
def loss_fn(params):
preds = state.apply_fn(params, batch['inputs'])
loss = jnp.mean((preds - batch['labels']) ** 2)
return loss
# Calcul des gradients et mise à jour
grad_fn = jax.grad(loss_fn)
grads = grad_fn(state.params)
new_state = state.apply_gradients(grads=grads)
return new_state
# Boucle d'entraînement
for epoch in range(num_epochs):
state = train_step(state, next_batch)
Travailler avec JAX et Flax n’est pas toujours simple – les ressources sont encore rares, et la communauté, bien que dynamique, reste petite. Mais c’est justement ce qui rend l’aventure passionnante. En recodant des architectures PyTorch ou TensorFlow "à la main" avec ces outils, j’ai dû comprendre ce qui se cache sous la couche d’abstraction des frameworks traditionnels : gestion des clés aléatoires, vectorisation explicite, gradients manuels… Ces contraintes m’ont forcée à maîtriser des concepts souvent occultés par la simplicité de torch.nn ou tf.keras. Et finalement, cette rigueur paye : les principes de différentiation automatique, de compilation XLA, ou de gestion fonctionnelle des états sont des compétences qui dépassent largement JAX. Aujourd’hui, même quand je reviens à PyTorch, je code mieux – en comprenant *pourquoi* les choses fonctionnent, pas seulement *comment*. C’est ça, la magie de ce duo : il vous rend meilleur, même quand vous ne l’utilisez plus.
Pour approfondir vos connaissances et explorer des outils avancés, voici quelques ressources incontournables :
Ma recommandation musicale du jour : à écouter sans modération !
Écouter sur YouTube