Après une longue absence, il est temps de plonger dans de nouveaux partages et découvertes. Aujourd'hui, je souhaite commencer une série d'articles en explorant les différentes couches (layers) qui composent les réseaux de neurones, avec un focus sur leur rôle, leur fonctionnement, et leur implémentation en JAX.
Cet article représente la première partie de cette exploration. Nous allons découvrir les briques fondamentales des réseaux neuronaux, en expliquant clairement à quoi servent ces différentes couches tout en fournissant des exemples de code simples et compréhensibles.
Les fonctions d'activation sont essentielles dans les réseaux neuronaux car elles introduisent de la non-linéarité. Sans elles, chaque couche du réseau serait simplement une transformation linéaire, et toutes les couches combinées ne formeraient qu'une seule transformation linéaire globale. Cela limiterait considérablement la capacité du modèle à apprendre des relations complexes entre les données.
La ReLU laisse passer les valeurs positives telles quelles et transforme toutes les valeurs négatives en zéro. Cette simplicité permet au réseau de capturer des relations non-linéaires sans alourdir les calculs.
def relu(input):
return jnp.maximum(0, input)
Rôle : Elle est rapide, efficace et empêche le réseau de "s'enliser" lors de l'apprentissage.
Le Softmax prend un ensemble de nombres et les transforme en probabilités. Bien qu'il soit principalement utilisé pour la classification multiclasse, il introduit également une forme de non-linéarité en normalisant les scores prédits.
def softmax(x, axis=-1):
x_max = jnp.max(x, axis=axis, keepdims=True)
x_shifted = x - x_max
exp_x = jnp.exp(x_shifted)
return exp_x / jnp.sum(exp_x, axis=axis, keepdims=True)
Rôle : Convertit des scores bruts en probabilités faciles à interpréter.
La Sigmoid transforme n'importe quelle valeur en un nombre compris entre 0 et 1. Elle est idéale pour les problèmes où il n'y a que deux réponses possibles.
def sigmoid(x):
return 1 / (1 + jnp.exp(-x))
Rôle : Parfait pour les problèmes binaires, comme détecter si une photo contient un chat ou non.
Maintenant que nous avons vu comment les fonctions d'activation introduisent de la non-linéarité, parlons des différentes "couches" qui forment les réseaux neuronaux. Chaque couche a un rôle spécifique.
Une couche dense relie chaque neurone de la couche précédente à tous ceux de la suivante. Elle effectue une transformation linéaire suivie d'une activation non-linéaire.
def initialize_dense_layer(key, input_dim, output_dim):
w_key, b_key = random.split(key)
limit = jnp.sqrt(6.0 / (input_dim + output_dim))
w = random.uniform(w_key, (input_dim, output_dim), minval=-limit, maxval=limit)
b = random.uniform(b_key, (output_dim,), minval=-limit, maxval=limit)
return w, b
def dense_layer(params, x):
w, b = params
return jnp.dot(x, w) + b
Rôle : Polyvalente et utilisée dans presque tous les types de réseaux neuronaux.
Cette couche s'assure que les données restent bien équilibrées pendant l'apprentissage.
def initialize_layer_norm(hidden_dim):
return jnp.ones(hidden_dim), jnp.zeros(hidden_dim)
def layer_norm(x, layernorm_params):
gamma, beta = layernorm_params
mean = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.var(x, axis=-1, keepdims=True)
return gamma * (x - mean) / jnp.sqrt(var + 1e-6) + beta
Rôle : Aide le réseau à converger plus rapidement et à éviter des erreurs lors de l'apprentissage.
Un MLP est une série de couches fully connected alternées avec des fonctions d'activation.
def initialize_mlp(hidden_dim, mlp_dim, key):
w1_key, w2_key = random.split(key)
limit = jnp.sqrt(6.0 / (hidden_dim + mlp_dim))
w1 = random.uniform(w1_key, (hidden_dim, mlp_dim), minval=-limit, maxval=limit)
b1 = jnp.zeros(mlp_dim)
w2 = random.uniform(w2_key, (mlp_dim, hidden_dim), minval=-limit, maxval=limit)
b2 = jnp.zeros(hidden_dim)
return w1, b1, w2, b2
def mlp(x, mlp_params):
w1, b1, w2, b2 = mlp_params
up_proj = relu(jnp.matmul(x, w1) + b1)
down_proj = jnp.matmul(up_proj, w2) + b2
return down_proj
Rôle : Modélise des problèmes très variés grâce à ses multiples couches.
L'attention auto permet au modèle de se concentrer sur les parties les plus importantes de son entrée.
head_dim = 64
num_heads = 4
def initialize_attention(hidden_dim, num_heads, head_dim, key):
q_key, k_key, v_key = random.split(key, 3)
fan_in = hidden_dim
fan_out = head_dim * num_heads
limit = jnp.sqrt(6.0 / (fan_in + fan_out))
q_w = random.uniform(q_key, (fan_in, fan_out), minval=-limit, maxval=limit)
q_b = jnp.zeros(fan_out)
k_w = random.uniform(k_key, (fan_in, fan_out), minval=-limit, maxval=limit)
k_b = jnp.zeros(fan_out)
v_w = random.uniform(v_key, (fan_in, fan_out), minval=-limit, maxval=limit)
v_b = jnp.zeros(fan_out)
return q_w, k_w, v_w, q_b, k_b, v_b
def self_attention(x, attn_params):
q_w, k_w, v_w, q_b, k_b, v_b = attn_params
n, d_k = x.shape
q = jnp.matmul(x, q_w) + q_b
k = jnp.matmul(x, k_w) + k_b
v = jnp.matmul(x, v_w) + v_b
q = q.reshape(n, num_heads, head_dim).swapaxes(0, 1)
k = k.reshape(n, num_heads, head_dim).swapaxes(0, 1)
v = v.reshape(n, num_heads, head_dim).swapaxes(0, 1)
attention_weights_heads = jnp.matmul(q, jnp.swapaxes(k, -1, -2)) / jnp.sqrt(head_dim)
attention_weights_heads = jax.nn.softmax(attention_weights_heads, axis=-1)
output = jnp.matmul(attention_weights_heads, v)
return output.swapaxes(0, 1).reshape(n, d_k)
Rôle : Traite efficacement des séquences, comme les phrases ou les vidéos.
Les embeddings transforment des informations discrètes, comme des mots, en vecteurs numériques.
def initialize_embedding(key, vocab_size, hidden_dim):
limit = jnp.sqrt(6.0 / (vocab_size + hidden_dim))
w = random.uniform(key, (vocab_size, hidden_dim), minval=-limit, maxval=limit)
return w
def embedding(x, embedding_params):
return embedding_params[x]
Rôle : Rend les données discrètes, comme les mots, exploitables par le réseau neuronal.
Dropout désactive aléatoirement certains neurones pendant l'apprentissage pour éviter le surapprentissage.
def dropout(key, x, rate, in_train_mode=True):
if in_train_mode:
mask = random.bernoulli(key, 1 - rate, x.shape)
return x * mask / (1.0 - rate)
return x
Rôle : Rend le modèle plus robuste et généralisable.
Cette couche normalise les données à chaque étape de l'apprentissage.
def initialize_batch_norm(hidden_dim):
return (
jnp.ones(hidden_dim),
jnp.zeros(hidden_dim),
jnp.zeros(hidden_dim),
jnp.ones(hidden_dim)
)
def batch_norm(params, inputs, train_mode=True, epsilon=1e-6, momentum=0.9):
gamma, beta, running_mean, running_var = params
if train_mode:
mean = jnp.mean(inputs, axis=0)
var = jnp.var(inputs, axis=0)
running_mean = momentum * running_mean + (1.0 - momentum) * mean
running_var = momentum * running_var + (1.0 - momentum) * var
x_hat = (inputs - mean) / jnp.sqrt(var + epsilon)
return gamma * x_hat + beta
else:
x_hat = (inputs - running_mean) / jnp.sqrt(running_var + epsilon)
return gamma * x_hat + beta
Rôle : Améliore la stabilité et la vitesse d'apprentissage.
Nous avons parcouru ensemble les principales couches utilisées dans les réseaux neuronaux, en mettant l'accent sur leurs rôles respectifs, leurs avantages, et leur implémentation en JAX. Chaque couche joue un rôle essentiel dans la construction de modèles performants et robustes, mais c'est surtout grâce aux fonctions d'activation que ces modèles peuvent capturer des relations complexes.
Cet article marque la première partie de notre exploration. Dans les prochains articles, nous verrons comment ces concepts peuvent être combinés pour résoudre des problèmes concrets, comme la reconnaissance d'images ou la traduction automatique.
Ma recommandation musicale du jour : à écouter sans modération !
Écouter sur YouTube