Jonathan Suru

Revenons ensemble vers l'essentiel

Après une longue absence, il est temps de plonger dans de nouveaux partages et découvertes. Aujourd'hui, je souhaite commencer une série d'articles en explorant les différentes couches (layers) qui composent les réseaux de neurones, avec un focus sur leur rôle, leur fonctionnement, et leur implémentation en JAX.

Cet article représente la première partie de cette exploration. Nous allons découvrir les briques fondamentales des réseaux neuronaux, en expliquant clairement à quoi servent ces différentes couches tout en fournissant des exemples de code simples et compréhensibles.

Les Fonctions d'Activation : Les Interrupteurs du Réseau

Les fonctions d'activation sont essentielles dans les réseaux neuronaux car elles introduisent de la non-linéarité. Sans elles, chaque couche du réseau serait simplement une transformation linéaire, et toutes les couches combinées ne formeraient qu'une seule transformation linéaire globale. Cela limiterait considérablement la capacité du modèle à apprendre des relations complexes entre les données.

ReLU (Rectified Linear Unit)

La ReLU laisse passer les valeurs positives telles quelles et transforme toutes les valeurs négatives en zéro. Cette simplicité permet au réseau de capturer des relations non-linéaires sans alourdir les calculs.


def relu(input):
    return jnp.maximum(0, input)
    

Rôle : Elle est rapide, efficace et empêche le réseau de "s'enliser" lors de l'apprentissage.

Softmax

Le Softmax prend un ensemble de nombres et les transforme en probabilités. Bien qu'il soit principalement utilisé pour la classification multiclasse, il introduit également une forme de non-linéarité en normalisant les scores prédits.


def softmax(x, axis=-1):
    x_max = jnp.max(x, axis=axis, keepdims=True)
    x_shifted = x - x_max
    exp_x = jnp.exp(x_shifted)
    return exp_x / jnp.sum(exp_x, axis=axis, keepdims=True)
    

Rôle : Convertit des scores bruts en probabilités faciles à interpréter.

Sigmoid

La Sigmoid transforme n'importe quelle valeur en un nombre compris entre 0 et 1. Elle est idéale pour les problèmes où il n'y a que deux réponses possibles.


def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))
    

Rôle : Parfait pour les problèmes binaires, comme détecter si une photo contient un chat ou non.

Les Couches Fondamentales : Construire un Réseau Neuronal

Maintenant que nous avons vu comment les fonctions d'activation introduisent de la non-linéarité, parlons des différentes "couches" qui forment les réseaux neuronaux. Chaque couche a un rôle spécifique.

Dense Layer (Couche Dense)

Une couche dense relie chaque neurone de la couche précédente à tous ceux de la suivante. Elle effectue une transformation linéaire suivie d'une activation non-linéaire.


def initialize_dense_layer(key, input_dim, output_dim):
    w_key, b_key = random.split(key)
    limit = jnp.sqrt(6.0 / (input_dim + output_dim))
    w = random.uniform(w_key, (input_dim, output_dim), minval=-limit, maxval=limit)
    b = random.uniform(b_key, (output_dim,), minval=-limit, maxval=limit)
    return w, b

def dense_layer(params, x):
    w, b = params
    return jnp.dot(x, w) + b
    

Rôle : Polyvalente et utilisée dans presque tous les types de réseaux neuronaux.

Layer Normalization (Normalisation par Couche)

Cette couche s'assure que les données restent bien équilibrées pendant l'apprentissage.


def initialize_layer_norm(hidden_dim):
    return jnp.ones(hidden_dim), jnp.zeros(hidden_dim)

def layer_norm(x, layernorm_params):
    gamma, beta = layernorm_params
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return gamma * (x - mean) / jnp.sqrt(var + 1e-6) + beta
    

Rôle : Aide le réseau à converger plus rapidement et à éviter des erreurs lors de l'apprentissage.

MLP (Multilayer Perceptron)

Un MLP est une série de couches fully connected alternées avec des fonctions d'activation.


def initialize_mlp(hidden_dim, mlp_dim, key):
    w1_key, w2_key = random.split(key)
    limit = jnp.sqrt(6.0 / (hidden_dim + mlp_dim))
    w1 = random.uniform(w1_key, (hidden_dim, mlp_dim), minval=-limit, maxval=limit)
    b1 = jnp.zeros(mlp_dim)
    w2 = random.uniform(w2_key, (mlp_dim, hidden_dim), minval=-limit, maxval=limit)
    b2 = jnp.zeros(hidden_dim)
    return w1, b1, w2, b2

def mlp(x, mlp_params):
    w1, b1, w2, b2 = mlp_params
    up_proj = relu(jnp.matmul(x, w1) + b1)
    down_proj = jnp.matmul(up_proj, w2) + b2
    return down_proj
    

Rôle : Modélise des problèmes très variés grâce à ses multiples couches.

Self-Attention (Attention Auto)

L'attention auto permet au modèle de se concentrer sur les parties les plus importantes de son entrée.


head_dim = 64
num_heads = 4

def initialize_attention(hidden_dim, num_heads, head_dim, key):
    q_key, k_key, v_key = random.split(key, 3)
    fan_in = hidden_dim
    fan_out = head_dim * num_heads
    limit = jnp.sqrt(6.0 / (fan_in + fan_out))
    q_w = random.uniform(q_key, (fan_in, fan_out), minval=-limit, maxval=limit)
    q_b = jnp.zeros(fan_out)
    k_w = random.uniform(k_key, (fan_in, fan_out), minval=-limit, maxval=limit)
    k_b = jnp.zeros(fan_out)
    v_w = random.uniform(v_key, (fan_in, fan_out), minval=-limit, maxval=limit)
    v_b = jnp.zeros(fan_out)
    return q_w, k_w, v_w, q_b, k_b, v_b

def self_attention(x, attn_params):
    q_w, k_w, v_w, q_b, k_b, v_b = attn_params
    n, d_k = x.shape
    q = jnp.matmul(x, q_w) + q_b
    k = jnp.matmul(x, k_w) + k_b
    v = jnp.matmul(x, v_w) + v_b
    q = q.reshape(n, num_heads, head_dim).swapaxes(0, 1)
    k = k.reshape(n, num_heads, head_dim).swapaxes(0, 1)
    v = v.reshape(n, num_heads, head_dim).swapaxes(0, 1)
    attention_weights_heads = jnp.matmul(q, jnp.swapaxes(k, -1, -2)) / jnp.sqrt(head_dim)
    attention_weights_heads = jax.nn.softmax(attention_weights_heads, axis=-1)
    output = jnp.matmul(attention_weights_heads, v)
    return output.swapaxes(0, 1).reshape(n, d_k)
    

Rôle : Traite efficacement des séquences, comme les phrases ou les vidéos.

Embedding

Les embeddings transforment des informations discrètes, comme des mots, en vecteurs numériques.


def initialize_embedding(key, vocab_size, hidden_dim):
    limit = jnp.sqrt(6.0 / (vocab_size + hidden_dim))
    w = random.uniform(key, (vocab_size, hidden_dim), minval=-limit, maxval=limit)
    return w

def embedding(x, embedding_params):
    return embedding_params[x]
    

Rôle : Rend les données discrètes, comme les mots, exploitables par le réseau neuronal.

Dropout

Dropout désactive aléatoirement certains neurones pendant l'apprentissage pour éviter le surapprentissage.


def dropout(key, x, rate, in_train_mode=True):
    if in_train_mode:
        mask = random.bernoulli(key, 1 - rate, x.shape)
        return x * mask / (1.0 - rate)
    return x
    

Rôle : Rend le modèle plus robuste et généralisable.

Batch Normalization (Normalisation par Mini-Lot)

Cette couche normalise les données à chaque étape de l'apprentissage.


def initialize_batch_norm(hidden_dim):
    return (
        jnp.ones(hidden_dim),
        jnp.zeros(hidden_dim),
        jnp.zeros(hidden_dim),
        jnp.ones(hidden_dim)
    )

def batch_norm(params, inputs, train_mode=True, epsilon=1e-6, momentum=0.9):
    gamma, beta, running_mean, running_var = params
    if train_mode:
        mean = jnp.mean(inputs, axis=0)
        var = jnp.var(inputs, axis=0)
        running_mean = momentum * running_mean + (1.0 - momentum) * mean
        running_var = momentum * running_var + (1.0 - momentum) * var
        x_hat = (inputs - mean) / jnp.sqrt(var + epsilon)
        return gamma * x_hat + beta
    else:
        x_hat = (inputs - running_mean) / jnp.sqrt(running_var + epsilon)
        return gamma * x_hat + beta
    

Rôle : Améliore la stabilité et la vitesse d'apprentissage.

Conclusion

Nous avons parcouru ensemble les principales couches utilisées dans les réseaux neuronaux, en mettant l'accent sur leurs rôles respectifs, leurs avantages, et leur implémentation en JAX. Chaque couche joue un rôle essentiel dans la construction de modèles performants et robustes, mais c'est surtout grâce aux fonctions d'activation que ces modèles peuvent capturer des relations complexes.

Cet article marque la première partie de notre exploration. Dans les prochains articles, nous verrons comment ces concepts peuvent être combinés pour résoudre des problèmes concrets, comme la reconnaissance d'images ou la traduction automatique.

Ma recommandation musicale du jour : à écouter sans modération !

Écouter sur YouTube